基于群表示特征的半监督学习图像分类方法
技术领域
本发明属于计算机视觉中的半监督分类领域,提供了一种基于群表示特征的半监督学习分类方法。
背景技术
深度学习模型已经成为计算机视觉应用的标准模型。它们的成功很大一部分取决于大型标注数据集的存在,比如,ImageNet,COCO等数据集提供了丰富的自然场景图片样本。根据经验来看,在较大的数据集上训练一般会得到性能更好地深度模型,深度学习通常通过有监督学习实现强大的性能,这就需要使用带有标签的数据。然而,对于有些任务来说,收集带标签数据是困难的,在进行人工标注的时候可能因为标志者的主观因素导致标注错误,或者标注数据需要专家知识,例如:医疗数据集,这会带来很大的成本消耗。相比之下,在大部分任务中,获取无标签数据是一件相对轻松的事情。
半监督学习就是一种不需要大量有标签数据就可以在大规模数据集上进行训练的有效方法,它通过允许模型学习未标记数据来大大减少对有标记数据的需求。许多半监督学习方法通常根据未标记数据对目标函数添加损失项,鼓励模型更好地归纳学习未标记数据的特征分布。目前,众多半监督学习方法中,一致性约束和伪标签是两种最常用的方法,同时也存在将两者相结合的方法。伪标签方法将模型对无标签数据的预测作为无标签数据的标签进行训练,而一致性约束方法将模型对无标签数据的预测分布作为标签进行训练。两种方法实现策略不同,但是从含义上都是对无标签数据生成人工标签进行训练。
在本项工作中,我们将沿用现有的SOTA方法的趋势,并结合基于群表示特征的一致性约束方案,构建更加有效的半监督学习分类方法。
发明内容
在最近的半监督分类方法中,常见的是通过对大量未标记数据使用一致性约束进行训练,要求模型预测对于输入样本的噪声具有不变性。我们使用协方差矩阵在流形空间中表示样本空间,来增强一致性训练性能。我们发现这样的方法和伪标签方法结合,可以得到更加有效的半监督学习分类模型。
基于群表示特征的半监督学习图像分类方法,步骤如下:
步骤一:对图像数据集进行预处理;
使用部分有标签图像,其余图像不使用标签;对每一张图片进行两种不同的数据增强方式,形成同一张图像的两种观察视角图像:
(1)进行随机水平翻转,裁剪图像成32×32的尺寸并最后进行归一化处理,得到的图像称为弱增强图像;
(2)进行随机水平翻转,裁剪图像成32×32的尺寸,随机图像增强策略并最后进行归一化处理,得到的图像称为强增强图像;
有标签图像只使用弱增强图像,无标签数据使用弱增强图像和强增强图像。
进一步地,有标签图像占所有图像数量百分比为小于5%。
进一步地,所述的随机图像增强策略包括对比度增强、亮度增强、色度增强、锐度增强、最大图像对比度、均衡图像直方图、将颜色通道上的变量bits置0、随机旋转、随机错切和反转像素点。在进行随机图像增强策略的过程中,随机采用上述的随机图像增强策略进行图像变换,并随机设置操作参数。
步骤二:构建两个相同的WiderResNet分类网络模型;
分类网络模型的宽度和深度参数分别为10、28,其中一个分类网络模型作为基础模型Pbase,另一个分类网络模型作为经验模型Pexp。使用不带Nesterov动量(牛顿动量)的SGD(Stochastic gradient descent,随机梯度下降)优化方法进行参数优化,初始学习率为1e-2,权重衰减参数为1e-3,并使用Cosine学习率衰减策略更新学习率。
步骤三:计算有标签图像在基础模型上的分类误差;
将有标签的弱增强图像IL_w输入基础模型Pbase得到对于输入图像的类别预测分布qL_w=Pbase(IL_w),根据标签Pb,使用交叉熵损失H计算得到有标签数据的分类损失函数Lsup:
其中,B表示每一个batch的大小。
步骤四:基础模型使用SGD优化器进行参数优化
对于经验模型,根据基础模型Pbase的模型参数θt,使用基于Momentum动量的加权平均方法更新经验模型Pexp的参数θ′t,下标t表示第t个迭代时,α为超参数:
θ′t=αθ′t-1+(1-α)θt (2)
步骤四中的经验模型构建方法。我们可以将一个数据特征生成算法表示成一个从到上的同构映射f可以线性的或者是非线性的。因此所有的f也可以构成一个拓扑流形。进一步我们可以认为f是连续且可微的,因此所有的f又构成一个微分流形,同时也是一个李群。
对于上的协方差矩阵群∑,由于f在无标签数据数据的特征上的作用,产生数据样本在新的特征空间上的表示因此在与f同构的映射的作用下,产生协方差矩阵群∑′。由于f一般是非线性映射,群∑和群∑′一般也不是线性同构,或者说不是同一个群的两个不同的线性表示,因此根据特征标理论,Λ∑≠Λ∑′。
要直接求解映射f是一件很困难的事情,但是我们可以使用神经网络来将f拟合出来。所以现在问题变成:构建一个映射f,使得我们对其输入无标签数据的时候可以得到一个具有辨别性与普遍性的特征图。在半监督学习中,我们会同时对有标签数据和无标签数据进行学习,学习有标签数据是为了学习到更加准确的特征提取方法,作为无标签数据特征提取的基础。所以我们会在基础模型的基础上得到经验模型,我们选择对基础模型参数进行动量加权平均来更新经验模型参数。之所以不直接用基础模型的最终参数,因为我们一般只会使用比较少的有标签数据,基础模型在有标签数据上会很快收敛并且达到过拟合状态,如果直接使用基础模型最终参数会影响经验模型的泛化能力,所以们在训练步骤上平均基础模型权重会得到更准确的经验模型。
步骤五:利用步骤四更新过后的经验模型,计算无标签数据部分的一致性约束损失Lconsisteny和伪标签损失Lpseudo;
将无标签弱增强图像IuL_w输入经验模型Pexp得到经验特征FuL_w和对于输入图像的类别预测分布quL_w=Pexp(IuL_w),同时无标签强增强图像IuL_s输入基础模型Pbase得到经验特征FuL_s和对于输入图像的类别预测分布quL_s=Pbase(IuL_s);其中,经验特征FuL_w和FuL_s是分类网络模型中最后一个全连接层的输入向量;
使用作为伪标签,通过均方差损失得到无标签数据部分的损失函数Lusp。其中表示一个可学习的特征映射矩阵;I是单位矩阵;ε和β是超参数:
其中,表示一个掩模向量,它的大小和H的输出一致,满足max(quL_w)>η条件的位置值为1,其他为0。η表示置信度阈值,quL_w表示类别预测分布,当quL_w对某个类别的预测置信度大于η时,才采用这次的预测;
本发明的主要创新点主要体现在步骤五中。弱增强的无标签数据通过将经验模型获得经验特征,强增强的无标签数据经过基础模型得到基础特征。在一致性约束理论和群表示特征理论的约束下,我们要求经验特征和基础特征的协方差矩阵的迹尽可能相似,这就是我们的一致性约束。对于伪标签部分,我们对经验特征进行简单的线性变化,预测样本的类别,这部分预测类别作为伪标签,和使用基础特征的预测类别计算交叉熵损失。上述两部分结合,就构成了我们的无标签数据损失。
步骤五中的一致性约束部分的群表示方法的详细描述。我们可以将数据空间看做是一个可分的拓扑空间,在定义了度量方法后该空间是一个拓扑流形。数据空间上的一般特性从统计学上看,是数据在数据空间上的分布特性,该特征具有一定的对称性,包括平移不变性和旋转不变形等,因此我们将一个batch的无标签数据表示为矩阵(B表示batch大小,D表示特征维度),能够表征这些对称性或者数据空间的分布特征的是协方差矩阵。在空间上不同的样本矩阵对加法构成了一个群∑,根据群表示的特征标理论,∑是一个线性群,对于群元σ∈∑,其矩阵的迹为tr(σ),而对于不同的∑上群元,我们可以的得到一个关于矩阵迹的函数Λ∑,称为群∑的特征标。由此,我们可以将无标签数据空间的一般特征表示为协方差矩阵群∑的特征标函数。
步骤五中所述的一致性约束损失。在对无标签数据的学习过程中,一个batch的基础数据在映射f,即经验模型,下得到对应的基础特征表示,对应地在映射F下得到新的关于经验特征的协方差矩阵ω,同时一个batch的加强数据在映射f′,即基础模型,下得到对应的经验特征表示,对应地在映射F′下得到新的关于基础特征的协方差矩阵ω′。对应于一致性约束的要求我们要求这两个协方差矩阵尽可能相似地相似,即我们可以要求两个协方差矩阵间的Log-Euclidean距离d尽可能得小:
d(ω,ω′)=||log(tr(ω))-log(tr(ω′)))||F (6)
如果直接使用公式(6)计算一致性约束,我们需要在得到两种特征表示的基础上分别计算协方差矩阵,再计算协方差矩阵的迹。这种方法会增加不必要的的计算量,所以我们使用另一种较为简便的等价形式计算来计算一致性约束。首先,依然需要计算得到经验特征和基础特征根据特征的群表示方法,我们可以认为FuL_w和FuL_s属于不同的数据空间,并认为存在一个特征映射矩阵可以将基础特征映射到经验特征空间中,使得在相同的特征空间中尽可能相近,所以我们得到公式(7):
公式(7)要求FuL_w和FuL_s在优化过程中越来越接近,如果存在完美的优化情况,两者最后应该变得相等,但是这两者是同一个batch的数据通过添加不同的数据增强方法得到了特征向量,让两者逐相等是一个很难实现,并且太强的约束条件,所以我们通过添加一个较小的偏置ε来中和这种情况,得到如公式(8):
所以我们对于一致性约束的最终优化目标如公式(9)所示:
步骤五中所述的伪标签方法损失。对于无标签数据,我们会为每一个样本计算一个人工标签,这个标签将用于计算无标签数据的标准交叉熵。为了得到人工标签,我们将借助之前得到的经验模型。首先计算经验模型对于基础数据的类别预测分布quL_s=Pbase(IuL_s),然后我们就可以使用作为伪标签。到这里我们已经可以得出所需要的伪标签目标函数,如公式(10)所示:
其中,η是一个表示阈值的标量超参数,预测概率高于阈值的我们将其保留作为一个伪标签。
步骤六:通过结合步骤三得到的有标签数据的损失函数Lsup和步骤五得到的无标签数据的损失函数Lusp,得到基于半监督学习的分类方法的最终损失函数。其中,λ是一个超参数,表示无标签数据损失所占的权重:
L=Lsup+λ·Lusp (4)
其中,第一部分是有标签数据的分类损失,第二部分是无标签数据的损失。这将结合权利要求1中的步骤三、步骤四、步骤五、步骤六完成,步骤三计算有标签数据的损失函数,步骤五计算无标签数据的损失函数,步骤四描述模型参数的优化过程。
步骤七:训练N个完整的周期epoch,用训练完成的基础模型作为最后的分类器;
步骤八:利用最后的分类器对新的图像进行分类。
所述的步骤二中,分类网络模型是使用半监督学习训练方法学习到的,半监督学习是一种不需要大量有标签数据就可以在大规模数据集上进行训练的有效方法,它通过允许模型学习未标记数据来大大减少对有标记数据的需求。半监督学习方法通常根据未标记数据对目标函数添加损失项,鼓励模型更好地归纳学习未标记数据的特征分布。一致性约束和伪标签是两种最常用半监督学习的方法。伪标签方法将模型对无标签数据的预测作为无标签数据的标签进行训练,而一致性约束方法将模型对无标签数据的预测分布作为标签进行训练。两种方法实现策略不同,但是都需要对同一张图像的不同角度的视图进行学习,所以对应于步骤一,我们对所有图像进行弱增强和强增强已获得同一张图像的不同视图。
所述的步骤二中,WiderResnet是残差网络ResNet的一种变体。传统的卷积网络或者全连接网络在信息传递的时候或多或少会存在信息丢失,损耗等问题,同时还有导致梯度消失或者梯度爆炸,导致很深的网络无法训练。残差网络ResNe在一定程度上解决了这个问题,通过直接将输入信息绕道传到输出,保护信息的完整性,整个网络只需要学习输入、输出差别的那一部分,简化学习目标和难度。ResNet最大的区别在于有很多的旁路将输入直接连接到后面的层,这种结构也被称为直连shortcut或者跳跃连接skip connections。本发明使用ResNet50的网络模型,ResNet50分为5个阶段,阶段0的结构比较简单,由一个7×7的卷积层和一个最大池化层组成,相当于对输入图像进行的预处理。后面4个阶段都由瓶颈层BottleNeck组成,结构都较为相近。阶段1包含3个BottleNeck,剩下的三个阶段分别包含4,6,3个BottleNeck。每个BottleNeck是由1×1,3×3和1×1的卷积网络串联在一起。ResNet的跳跃连接,导致只有少量的BottleNeck学习到有用的信息,所以通过在ResNet的基础上减少深度,增加宽度,得到一种新的网络WiderResNet。WiderResNet增大每一个BottleNeck中的卷积核的数量,通过宽度因子参数表示增大宽度的大小,宽度因子越大网络越宽,另外WiderResNet还在卷积层之间加入抛弃层Dropout。在本发明中,我们将阶段1中的卷积层的卷积核改为3×3。WiderResNet的深度因子为28,宽度因子为10。在阶段4之后添加一个全局平均值化层和两个全连接层。第一个全连接层为中间特征输出层,输出中间特征,对应于经验模型的到的是经验特征,对应于基础模型的是基础特征;第二个全连接层为类别预测层,以中间特征作为输入,其输出特征经过一个Softmax函数作为类别预测概率。
与现有技术相比,本公开的有益效果为:
1.本发明方法对于有标签样本的需求量少;
2.本发明方法通过基于群表示的特征表示方法,提高了半监督分类方法的准确率。
附图说明
图1为模型学习过程的图解;
图2为模型超参数选择实验结果,(a)CIFAR-10时错误率与一致性约束中的偏置关系图;(b)CIFAR-100时错误率与一致性约束中的偏置关系图;(c)SVHN时错误率与一致性约束中的偏置关系图;(d)CIFAR-10时错误率与置信度阈值关系图;(e)CIFAR-100时错误率与置信度阈值关系图;(f)SVHN时错误率与置信度阈值关系图。
具体实施方式
1.特征的群表示方法
我们可以将数据空间看做是一个可分的拓扑空间,在定义了度量方法后该空间是一个拓扑流形。数据空间上的一般特性从统计学上看,是数据在数据空间上的分布特性,该特征具有一定的对称性,包括平移不变性和旋转不变形等,因此我们将一个batch的无标签数据表示为矩阵(N表示batch大小,D表示特征维度),能够表征这些对称性或者数据空间的分布特征的是协方差矩阵。在空间上不同的样本矩阵对加法构成了一个群∑,根据群表示的特征标理论,∑是一个线性群,对于群元σ∈∑,其矩阵的迹为tr(σ),而对于不同的∑上群元,我们可以的得到一个关于矩阵迹的函数Λ∑,称为群∑的特征标。由此,我们可以将无标签数据空间的一般特征表示为协方差矩阵群∑的特征标函数。
2.经验模型
我们可以将一个数据特征生成算法表示成一个从到上的同构映射 f可以线性的或者是非线性的。因此所有的f也可以构成一个拓扑流形。进一步我们可以认为f是连续且可微的,因此所有的f又构成一个微分流形,同时也是一个李群。
对于上的协方差矩阵群∑,由于f在无标签数据数据的特征上的作用,产生数据样本在新的特征空间上的表示因此在与f同构的映射的作用下,产生协方差矩阵群∑′。由于f一般是非线性映射,群∑和群∑′一般也不是线性同构,或者说不是同一个群的两个不同的线性表示,因此根据特征标理论,Λ∑≠Λ∑′。
要直接求解映射f是一件很困难的事情,但是我们可以使用神经网络来将f拟合出来。所以现在问题变成:构建一个映射f,使得我们对其输入无标签数据的时候可以得到一个具有辨别性与普遍性的特征图。在半监督学习中,我们会同时对有标签数据和无标签数据进行学习,学习有标签数据是为了学习到更加准确的特征提取方法,作为无标签数据特征提取的基础。所以我们会在基础模型的基础上得到经验模型,我们选择对基础模型权重取平均来构建经验模型。之所以不直接用基础模型的最终参数,因为我们一般只会使用比较少的有标签数据,基础模型在有标签数据上会很快收敛并且达到过拟合状态,如果直接使用基础模型最终参数会影响经验模型的泛化能力,所以们在训练步骤上平均基础模型权重会得到更准确的经验模型。
3.混合的损失函数
我们将分类模型的损失函数分为两个部分,一部分是对有标签数据的损失函数Lsup和对无标签数据的损失函数Lusp。Lsup仅仅是有标签数据的标准交叉熵损失,Lusp由两部分组成,分别是一致性约束损失和伪标签损失。
4.一致性约束损失
在对无标签数据的学习过程中,一个batch的基础数据在映射f,即经验模型,下得到对应的基础特征表示,对应地在映射F下得到新的关于经验特征的协方差矩阵ω,同时一个batch的加强数据在映射f′,即基础模型,下得到对应的经验特征表示,对应地在映射F′下得到新的关于基础特征的协方差矩阵ω′。对应于一致性约束的要求我们要求这两个协方差矩阵尽可能相似地相似,即我们可以要求两个协方差矩阵间的Log-Euclidean距离尽可能得小:
如果直接使用公式(1)计算一致性约束,我们需要在得到两种特征表示的基础上分别计算协方差矩阵,再计算协方差矩阵的迹。这种方法会增加不必要的的计算量,所以我们使用另一种较为简便的等价形式计算来计算一致性约束。首先,依然需要计算得到经验特征表示和基础特征表示根据特征的群表示方法,我们可以认为φ和ψ属于不同的数据空间,并认为存在一个特征映射矩阵可以将基础特征映射到经验特征空间中,使得在相同的特征空间中尽可能相近,所以我们得到公式(2):
公式(1)要求φ和ψ在优化过程中越来越接近,如果存在完美的优化情况,两者最后应该变得相等,但是这两者是同一个batch的数据通过添加不同的数据增强方法得到了特征向量,让两者逐相等是一个很难实现,并且太强的约束条件,所以我们通过添加一个较小的偏置ε来中和这种情况,得到如公式(3):
所以我们对于一致性约束的最终优化目标如公式(4)所示:
5.伪标签
对于无标签数据,我们会为每一个样本计算一个人工标签,这个标签将用于计算无标签数据的标准交叉熵。为了得到人工标签,我们将借助之前得到的经验模型。首先计算经验模型对于基础数据的类别预测分布qb=Pexp(y|α(ub)),然后我们就可以使用作为伪标签。到这里我们已经可以得出所需要的伪标签目标函数,如公式(5)所示:
其中,η是一个表示阈值的标量超参数,预测概率高于阈值的我们将其保留作为一个伪标签。
实际案例
1.标准数据集
首先,我们将比较本发明方法和现有的方法在半监督学习基准数据集(CIFAR-10,CIFAR-100,SVHN)上的表现。CIFAR-10包含训练集50000张图片,测试集10000张图片,共包括10种类别。CIFAR-100包含训练集50000张图片,测试集10000张图片,共包括100个类别。SVHN包含训练集73257张图片,测试集26032张图片,共包括10个类别。
2.实验环境及参数设置
实验室CPU型号为I7-5930k,内存为32GB,显卡为GeForce 1080Ti,显存为11GB。我们使用Pytorch编写模型,使用不带Nesterov动量的SGD作为优化方法,初始化学习率为1e-2,weight decay是1e-3,使用Cosine学习率衰减策略。对于所有的超参数,在CIFAR-10和SVHN中,λ=20.0,η=0.95,α=0.97,ε=0.25,β=1.0。在CIFAR-100中,λ=20.0,η=0.95,α=0.97,ε=0.15,β=1.0。
3.实验结果
在CIFAR-10中,本发明方法在每类使用400张有监督图片是获得最低的错误率,3.55%。ReMixMatch,UDA和FixMatch在使用250和500张有标签图片的情况下都可以获得比较优秀的结果,但是可以看到的是本发明方法可以获得更稳定,优秀的结果,本发明方法仅在只是用250张有标签图片,错误率略高于FixMatch 0.01%,但是却优于其他方法,并且在其他所有情况中本发明方法都展现出最优的结果。在SVHN数据集中,本发明方法在每类使用400张有监督图片时获得最低的错误率,2.27%。本发明方法可以获得三项的最好成绩,但是在其他两项实验中,我们的方法任然可以获得第二位的性能,并且错误率仅略高于最好情况。
除了CIFAR-100中,我们可以发现本发明方法都可以取得非常满意的性能,但是在CIFAR-100中ReMixMatch得到出在所有情况下的最优性能。
表(1):CIFAR-10上5种不同的无标签数据使用数量下的错误率。
所有的基准模型都是用相同的代码库进行测试
表(2):CIFAR-100上4中不同的无标签数据使用数量下的错误率。
所有的基准模型都是用相同的代码库进行测试。
表(3):SVHN上5中不同的无标签数据使用数量下的错误率。
所有基准模型都是用相同的代码库进行测试。