基于单元重要度的条件计算方法

文档序号:8711 发布日期:2021-09-17 浏览:28次 英文

基于单元重要度的条件计算方法

技术领域

本发明涉及一种基于单元重要度的条件计算方法。

背景技术

目前,深度学习模型压缩主要包括裁剪、量化、知识蒸馏等。其中,裁剪按照粒度区分可分为神经元级裁剪、滤波器级裁剪、甚至残差单元级裁剪等,考虑实际应用场景中通用处理器的实际推理加速效果,通常采用的是滤波器级或残差单元级裁剪。常见的裁剪方案中通常设计一个滤波器或一个残差单元重要性评估指标,然后衡量每一个裁剪候选单元的重要性,并裁去重要度较低的直到模型的计算复杂度符合要求。

条件计算是一种较新的深度学习模型压缩手段,它利用了不同滤波器或不同残差单元提取的特征互不相同,以及不同输入图像拥有不同特征的特点,个性化地根据输入图像的不同决策出恰当的计算路径。现有的条件计算方法主要为残差单元级粒度的条件计算,通常通过强化学习训练一个小型门控网络用于根据输入或中间特征图预测各个残差单元的开闭。

但现有的条件计算方法大多采用强化学习,根据分类的交叉熵损失与裁剪率构建强化学习的reward,并将该reward返回给所有门控的输出进行训练。这使得门控网络的搜索空间非常大,在数据集容量有限的情况下很难实现良好的动态裁剪。

发明内容

本发明提供了一种基于单元重要度的条件计算方法,采用如下的技术方案:

一种基于单元重要度的条件计算方法,包含以下步骤:

S1:预先训练主干残差网络M,主干残差网络M包含n个残差单元;

S2:为预训练好的主干残差网络M构建门控网络G;

S3:计算主干残差网络M中每个残差单元对每一张输入图像的重要度;

S4:将输入图像及其对应的各残差单元的重要度组成为输入-标签对,构建数据集,固定主干残差网络M,通过数据集训练门控网络G;

S5:在训练好门控网络G后固定门控网络G,对主干残差网络M进行微调以适应动态裁剪;

S6:重复步骤S3-S5直到模型的裁剪率和精度满足预设条件。

进一步地,在步骤S3中计算主干残差网络M中每个残差单元对每一张输入图像的重要度的具体方法为通过下述公式进行计算:

imp(x,i)=loss(M-Block[i],x)-loss(M,x)

其中,x为输入图像,M-Block[i]为M中第i个残差单元被裁去时的剩余n-1个残差单元构成的子网络,function为给定的当前任务的目标函数,imp(x,i)为M中第i个残差单元对输入x的重要度。

进一步地,在步骤S4中,将重要度标注作为reward,门控网络G的输出G(x)作为各个门控的预测值,经过Sigmoid函数将门控预测值转化为开启概率后,使用类强化学习的算法对门控网络G进行训练。

进一步地,步骤S4中的目标函数通过下述公式进行计算,

其中,G(x)为各个门控的预测值,训练采用梯度上升以最大化目标函数。

进一步地,在步骤S5中对主干残差网络M进行微调时,使得每个输入图像都只经过所有n个残差单元的特定子集,对于某个输入图像,主干残差网络M的微调只对特定子集中的残差单元进行。

进一步地,步骤S2中构建的门控网络为ResNet8卷积神经网络,或以LSTM循环神经网络为主体的神经网络,或者是n个独立的MLP,每个MLP对应于一个残差单元。

本发明的有益之处在于所提供的基于单元重要度的条件计算方法,首先预先训练主干残差网络M,然后为预训练的主干残差网络M构建门控网络G用于预测所有主干残差网络M中残差单元的重要度与开闭。为了训练门控网络M,计算训练集上主干残差网络M中每个残差单元对每一张输入图像的重要度,并以此构建数据集用于训练门控网络G,使门控网络G能够根据输入图像与中间特征图,预测出不同残差单元的重要度。从而能够在推理阶段动态地对不同输入裁去重要度低,或者对当前输入无效甚至有害的残差单元以实现模型裁剪与精度提升。

附图说明

图1是本发明的基于单元重要度的条件计算方法的示意图。

具体实施方式

以下结合附图和具体实施例对本发明作具体的介绍。

如图1所示为本发明一种基于单元重要度的条件计算方法,主要包含以下步骤:步骤S1:预先训练主干残差网络M,主干残差网络M包含n个残差单元。步骤S2:为预训练好的主干残差网络M构建门控网络G。门控网络G用于控制主干残差网络M中n个残差单元的开闭。若残差单元开启,则前向推理时该残差单元被正常计算;若残差单元关闭,则前向推理时该残差单元中只有短连接被经过,残差单元被裁去而不需要做任何计算。步骤S3:选取若干输入图像,计算主干残差网络M中每个残差单元对每一张输入图像的重要度。步骤S4:将输入图像及其对应的各残差单元的重要度组成为输入-标签对,构建数据集,固定主干残差网络M,通过数据集训练门控网络G。步骤S5:在训练好门控网络G后固定门控网络G,对主干残差网络M进行微调以适应动态裁剪。步骤S6:重复步骤S3-S5直到模型的裁剪率和精度满足预设条件。通过上述步骤,首先预先训练主干残差网络M,然后为预训练的主干残差网络M构建门控网络G用于预测所有主干残差网络M中残差单元的重要度与开闭。为了训练门控网络M,计算训练集上主干残差网络M中每个残差单元对每一张输入图像的重要度,并以此构建数据集用于训练门控网络G,使门控网络G能够根据输入图像与中间特征图,预测出不同残差单元的重要度。从而能够在推理阶段动态地对不同输入裁去重要度低,或者对当前输入无效甚至有害的残差单元以实现模型裁剪与精度提升。

作为一种优选的实施方式,在步骤S3中,计算主干残差网络M中每个残差单元对每一张输入图像的重要度的具体方法为通过下述公式进行计算:

imp(x,i)=loss(M-Block[i],x)-loss(M,x)

其中,x为输入图像,M-Block[i]为M中第i个残差单元被裁去时的剩余n-1个残差单元构成的子网络,loss为给定的当前任务的损失函数,imp(x,i)为M中第i个残差单元对输入x的重要度。

作为一种优选的实施方式,在步骤S4中,将重要度标注作为reward,门控网络G的输出G(x)作为各个门控的预测值,经过Sigmoid函数将门控预测值转化为开启概率后,使用类强化学习的算法对门控网络G进行训练。

作为一种优选的实施方式,步骤S4中的损失函数通过下述公式进行计算,

其中,G(x)为各个门控的预测值,训练采用梯度上升以最大化目标函数。

作为一种优选的实施方式,在步骤S5中对主干残差网络M进行微调时,使得每个输入图像都只经过所有n个残差单元的特定子集,对于某个输入图像,主干残差网络M的微调只对特定子集中的残差单元进行。

具体而言,由于残差单元粒度的裁剪破坏了主干残差网络M预训练过程中BN层统计的数据分布信息,包括running_mean、running_var等。在正式适用门控网络G进行动态裁剪前我们还需要固定门控网络G,并在门控网络G的指导下进行动态裁剪,使得每个输入图像x都只经过所有n个残差单元的特定子集。譬如对于输入x0,在门控网络G的指导下我们裁掉了第3、第6个残差单元,此时x0对应的需要通过的残差单元子集为U={Block[1]、Block[2]、Block[4]、Block[5]、Block[7]、…、Block[n]},且在整个步骤S5这一步的微调环节中,图像x0都只会利用U中的残差单元进行推理,对于图像x0,主干残差网络的微调也只会对U中的残差单元进行。

作为一种优选的实施方式,步骤S2中构建的门控网络为卷积神经网络。卷积神经网络为ResNet8。门控网络G独立于主干残差网络M,直接接收输入图像为网络输入,并在全连接层输出所有门控的预测结果。

使用卷积神经网络型的门控网络能够使我们在主干残差网络运作之前就能一次性获得所有门控的预测结果,便于我们事先进行单元裁剪的决策,同时门控网络的开销不会随主干网络容量的增大而变大。

当采用卷积神经网络型的门控网络时,由于能够事先获得所有门控的预测结果,因此可以直接使用贪心法:找出重要度最低的一个或若干个残差单元裁去。也可以采用阈值法:设置阈值α,裁去所有-G(x)>α的单元并保留-G(x)<α;也可首先计算Softmax(-G(x))后设定阈值α,裁去Softmax(-G(x))>α的单元并保留Softmax(-G(x))<α的单元。

作为一种可选的实施方式,作为一种优选的实施方式,步骤S2中构建的门控网络为以L循环神经网络为主体的神经网络。作为一种优选的实施方式,循环神经网络为LSTM。使用LSTM等循环神经网络作为门控网络,将主干残差网络中每个残差单元的输入特征图组成序列,降维后输入门控网络,门控网络逐个对序列中每个残差单元对应的门控进行预测。

使用循环神经网络型的门控网络能够利用浅层所有残差单元的序列信息协同进行下一个残差单元门控的预测。

作为另一种可选的实施方式,步骤S2中构建的门控网络为n个独立的MLP(Multilayer Perceptron,多层感知器),每个MLP对应于一个残差单元。使用MLP型的门控网络,每个单元分配独立的门控单元使得门控网络的训练更加容易且稳定。

当采用循环神经网络型或MLP型门控网络时,由于无法事先获得所有门控的预测结果,需要在主干残差网络M前向推理的过程中同时进行动态裁剪的决策,只能采用阈值法。在精度敏感且计算开销限制较低的场合下,也可以首先对主干残差网络M进行一次前向推理用于收集所有门控的预测结果后,使用贪心法指导动态裁剪并重新进行一次主干残差网络M的前向推理。

以上显示和描述了本发明的基本原理、主要特征和优点。本行业的技术人员应该了解,上述实施例不以任何形式限制本发明,凡采用等同替换或等效变换的方式所获得的技术方案,均落在本发明的保护范围内。

完整详细技术资料下载
上一篇:石墨接头机器人自动装卡簧、装栓机
下一篇:一种深度卷积神经网络加速方法、模块、系统及存储介质

网友询问留言

已有0条留言

还没有人留言评论。精彩留言会获得点赞!

精彩留言,会给你点赞!