基于一致性训练的半监督三维形状识别方法
技术领域
本发明涉及计算机视觉技术,具体是涉及一种基于一致性训练的半监督三维形状识别方法。
背景技术
三维视觉的研究对于自动驾驶、增强现实、机器人等应用有着重要的作用。随着深度学习的快速发展,研究者们提出很多用于三维形状识别任务的方法。当前主流的三维形状识别方法主要分为三种。第一种是基于多视角的方法,该方法将点云投影到多个二维视角,然后直接使用经典的二维卷积神经网络进行处理。投影得到的任一视角都会单独被二维卷积神经网络处理,然后使用视角-池化层将产生的各个视角的特征进行融合。多视角方法由于自遮挡会丢失一些关键的信息。第二种是基于体素的方法,该方法将点云体素化成规则的三维网格,然后使用三维卷积和池化操作进行处理,这会消耗大量的时间和空间资源。三维网格的稀疏性也会造成资源的浪费。最近几年基于点云的方法受到很多关注,该方法直接利用原始的点云数据作为输入。其中,Qi,C.等人提出的方法(Qi,C.,HaoSu,KaichunMo,et al.“PointNet:Deep Learning on Point Sets for 3D Classificationand Segmentation.”2017IEEE Conference on Computer Vision and PatternRecognition(CVPR)(2017):77-85.)是直接处理原始点云数据的开拓者,它对每一个点进行单独编码,最后使用全局池化来聚集所有点的特征信息。但是它不能捕获到三维物体的局部细节。因此,Qi,C.等人(Qi,C.,L.Yi,Hao Su,et al.“PointNet++:Deep HierarchicalFeature Learning on Point Sets in aMetric Space.”NIPS(2017).)又提出一个分层神经网络来提取局部特征。Wang,Yue等人(Wang,Yue,Yongbin Sun,Z.Liu,et al.“DynamicGraph CNN for Learning on Point Clouds.”ACM Transactions on Graphics(TOG)38(2019):1-12.)提出边卷积操作,并在卷积过程会动态的更新局部分组。上述方法实现很好的性能,但这些方法都是基于全监督设置的,需要大量的有标签数据。
点云数据研究的成功主要归功于强大的卷积神经网络和大量的有标签点云数据。尽管大多数方法致力于提升模型本身的准确率,获取大规模的有标签数据集也是一个棘手的问题。目前,由于深度传感器的进步,点云数据的获取变得更加方便和便宜。由于数据标注需要大量的人力,并且需要标注人员具有很强的专业知识,这使得获取有标签点云数据的成本非常昂贵。
半监督学习通过利用少量的有标签数据和大量的无标签数据来解决这一问题。近几年,半监督学习在二维图像处理上取得很大成就,实现与监督方法可比的性能。但是,用于三维点云分类的半监督方法屈指可数。Song,Mofei等人提出的方法(Song,Mofei,Y.Liuand Xiao Fan Liu.“Semi-Supervised 3D Shape Recognition via Multimodal DeepCo-training.”Computer Graphics Forum 39(2020):n.pag.)是第一个用于三维形状分类的半监督方法。该方法使用多模态网络进行协同训练,需要同时基于点云数据和多视角数据的两个分类网络进行训练。因此它需要两种数据表示才能训练,这使得训练集的数据获取变得更加困难。
发明内容
本发明的目的在于针对现有技术存在的上述问题,提供实现对有限的有标签数据集的扩充,结合一致性约束分支和伪标签生成分支训练深度模型,利用训练好的模型进行三维形状分类一种基于一致性训练的半监督三维形状识别方法。
本发明包括以下步骤:
A.准备三维形状数据集,包括有标签数据集和无标签数据集;
B.对无标签数据添加微小扰动得到扰动版本的无标签数据集;
C.设计一致性约束分支鼓励模型对相似样本预测一致,提高模型的泛化能力;
D.设计伪标签生成分支为无标签数据生成伪标签,并提出一致性过滤机制过滤掉模型不确定的伪标签,实现对有标签数据集的扩充;
E.结合有标签数据和无标签数据训练模型,得到训练好的模型;
F.使用训练好的模型进行三维形状识别,将模型的预测作为最终的识别结果。
在步骤A中,所述准备三维形状数据集进一步包括以下子步骤:
A1.准备有标签数据集,使用Dl={(xi,yi):i∈(1,...,m)}表示有标签数据,其中xi∈RN×F表示由N个具有F维特征的点组成的三维形状,yi∈{1,...,C}表示数据xi的类别标签,C表示数据集中包含的三维形状的总类别数,m表示有标签数据的数量;
A2.准备无标签数据集,使用Du={xj:j∈(1,...,n)}示无标签数据,其中xj∈RN×F示由N个具有F维特征的点组成的三维形状,n表示无标签数据的数量。
在步骤B中,所述对无标签数据添加微小扰动得到扰动版本的无标签数据集进一步包括以下子步骤:
B1.将微小扰动r添加到三维形状的xyz坐标信息上,对三维形状造成轻微形变且不会改变三维形状的类别语义;由于三维形状的大小不一,若对所有的三维形状添加同样大小的扰动,可能会造成一些三维形状的严重变形,按三维形状的半径R对扰动进行缩放,得到扰动版本的无标签形状x'j,x'j的计算方法如下:
x'j=xj+R·r (1)
在步骤C中,所述设计一致性约束分支鼓励模型对相似样本预测一致进一步包括以下子步骤:
C1.由于有标签数据的数量有限,设计一个一致性约束分支来提高模型的泛化能力,该分支要求模型对于相似样本应该预测为相同类别,起到平滑模型的作用;对于原始的无标签点云数据x'j以及扰动版本的无标签点云数据x'j,模型的预测应该一致;使用模型预测原始无标签点云数据xj得到预测分布f(xj),使用模型预测扰动版本的无标签点云数据x'j得到预测分布f(x'j),一致性约束损失的计算公式如下:
其中,KL是Kullback-Leibler散度,用于衡量两个预测分布的差距。
在步骤D中,所述设计伪标签生成分支为无标签数据生成伪标签,并提出一致性过滤机制过滤掉模型不确定的伪标签,实现对有标签数据集的扩充进一步包括以下子步骤:
D1.使用当前模型对无标签数据xj进行预测得到f(xj),以预测分布的类别概率最大的类别作为该数据的伪标签yp=argmax(f(xj));
D2.提出一致性过滤机制来过滤模型不确定的伪标签,仅选择模型对原始点云数据和扰动版本的点云数据具有一致预测时,才将原始点云数据加入候选集;使用当前的模型对扰动版本的无标签数据x'j进行预测得到f(x'j),若argmax(f(xj))=argmax(f(x'j)),则将原始数据xj以及它的伪标签yp加入候选集;
D3.从候选集中挑选置信度大于一定阈值的带有伪标签的无标签数据加入最终的伪标签数据集Dp;
D4.伪标签数据集Dp中的数据将和有标签数据一起用于训练中监督损失的计算,监督损失的计算公式如下:
其中,β是超参数,表示伪标签数据的监督损失的相对权重。
在步骤E中,所述结合有标签数据和无标签数据训练模型,得到训练好的模型进一步包括以下子步骤:
E1.模型的总损失函数是一致性损失函数和监督损失函数的总和,计算方法如下:
lsum=lsup+α·lcon (4)
其中,α是超参数;
E2.结合一致性约束分支和伪标签生成分支进行训练,得到训练好的模型进行三维形状识别。
本发明建立深度模型,包括一致性约束分支和伪标签生成分支;首先准备三维形状数据集,包括有标签数据集和无标签数据集。并对无标签数据添加微小扰动得到扰动版本的无标签数据集。使用设计的一致性约束分支来提高模型的泛化能力。使用设计的伪标签生成分支为无标签数据生成伪标签,并提出一致性过滤机制过滤掉模型不确定的伪标签,从而扩充有标签数据集。结合一致性约束分支和伪标签生成分支训练深度模型,利用训练好的模型进行三维形状分类。
附图说明
图1为本发明实施例的半监督三维形状识别框架示意图。
图2为在三维形状数据集ModelNet40上,本发明提出的半监督方法与监督方法在不同比例的有标签数据情况下的结果对比。
具体实施方式
以下结合附图和实施例对本发明的方法作详细说明,本实施例在以本发明技术方案为前提下进行实施,给出实施方式和具体操作过程,但本发明的保护范围不限于下述的实施例。
本发明首先准备有标签三维形状数据集和无标签三维形状数据集,通过给无标签数据添加微小扰动得到扰动版本的无标签三维形状数据集。使用一致性约束分支鼓励模型对于原始的无标签形状和扰动版本的无标签形状预测一致,从而提高模型的泛化能力。使用伪标签生成分支为无标签数据生成伪标签,并提出一致性过滤机制过滤掉模型不确定的伪标签,实现对有限的有标签数据集的扩充。结合有标签数据和无标签数据进行训练,得到训练好的模型进行三维形状识别。
参见图1和2,本发明实施例的实施方式包括以下步骤:
1.准备三维形状数据集,包括有标签数据集和无标签数据集。
A.采用三维形状基准数据集ModelNet40(Wu,Zhirong,Shuran Song,A.Khosla,etal.“3D ShapeNets:A deep representation for volumetric shapes.”2015IEEEConference on Computer Vision and Pattern Recognition(CVPR)(2015):1912-1920.),ModelNet40有12311个形状,共有40个类别,其中9843个形状用于训练,3991个形状的用于验证。
B.从训练集中随机采样10%的数据以及它的标签作为有标签数据,使用Dl={(xi,yi):i∈(1,...,m)}表示有标签数据集,其中xi∈R1024×3表示由1024个仅带有xyz坐标信息的点组成的三维形状,yi∈{1,...,C}表示数据xi的类别标签,C表示数据集中包含的三维形状的总类别数,m表示有标签数据的数量。
C.将训练集中的所有数据作为无标签数据,使用Du={xj:j∈(1,...,n)}表示无标签数据集,其中xj∈R1024×3表示由1024个仅带有xyz坐标信息的点组成的三维形状,n表示无标签数据的数量。
2.给无标签数据添加微小扰动获得扰动版本的无标签数据集。
A.采用虚拟对抗扰动(Miyato,Takeru,S.Maeda,Masanori Koyama and S.Ishii.“Virtual Adversarial Training:A Regularization Method for Supervised andSemi-Supervised Learning.”IEEE Transactions on Pattern Analysis and MachineIntelligence 41(2019):1979-1993.)作为添加的微小扰动r。
B.将微小扰动r添加到无标签三维点云数据xj的xyz坐标上,对三维形状造成轻微形变,且不会改变三维形状的类别语义。由于三维形状的大小不一,若对所有的三维形状添加相同大小的虚拟对抗扰动,可能会改变有些三维形状的类别语义,因此根据三维形状的半径R来对扰动进行缩放,最终得到扰动版本的无标签点云数据x'j,x'j计算方法如下:
x'j=xj+R·r (1)
3.设计一致性约束分支。
A.由于有标签数据的数量有限,直接利用有标签数据进行训练,很容易导致模型过拟合。因此,设计一个一致性约束分支来提高模型的泛化能力。该分支要求模型对于相似样本应该预测为相同类别,起到平滑模型的作用。对于无标签数据xj,使用模型预测得到预测结果为f(xj),对于扰动版本的无标签数据x'j,使用模型预测得到f(x'j),一致性损失函数的计算公式如下:
其中KL是Kullback-Leibler散度。
4.设计伪标签生成分支为无标签数据生成伪标签,并提出一致性过滤机制过滤掉模型不确定的伪标签,实现对有标签数据集的扩充。
A.使用当前模型对无标签数据xj进行预测得到f(xj),以预测分布的类别概率最大的类别作为该数据的伪标签yp=argmax(f(xj))。
B.由于一开始模型的性能较差,会产生很多错误的伪标签。若直接将大量的错误伪标签用于训练,会导致噪声训练。因此提出一致性过滤机制来过滤模型不确定的伪标签,仅选择模型对原始点云数据和扰动版本的点云数据具有一致预测时,才将原始点云数据加入候选集。使用当前的模型对扰动版本的无标签数据x'j进行预测得到f(x'j),若argmax(f(xj))=argmax(f(x'j)),则将原始数据xj以及它的伪标签yp加入候选集。
C.然后从候选集中挑选置信度大于一定阈值的带有伪标签的无标签数据加入最终的伪标签数据集Dp。
D.伪标签数据集Dp中的数据将和有标签数据一起用于训练中监督损失的计算,监督损失的计算公式如下:
其中β是超参数,表示伪标签数据的监督损失的相对权重。
5.结合有标签数据和无标签数据训练模型。
A.模型的总损失函数是一致性损失函数和监督损失函数的总和,计算方法如下:
lsum=lsup+α·lcon (4)
其中α是超参数。
B.结合一致性约束分支和伪标签生成分支进行训练,得到训练好的模型进行三维形状识别。
表1为在三维形状数据集ModelNet40上,本发明提出的半监督方法与其他方法的结果对比。可见,本发明方法相比其他方法准确率较高。
表1
在表1中,其他方法如下:
OctNet对应Riegler,G.,等人提出的方法(Riegler,G.,Ali O.Ulusoy andAndreas Geiger.“OctNet:Learning Deep 3D Representations at High Resolutions.”2017IEEE Conference on Computer Vision and Pattern Recognition(CVPR)(2017):6620-6629.)
MVCNN对应Su,Hang,等人提出的(Su,Hang,Subhransu Maji,E.Kalogerakis,etal.“Multi-view Convolutional Neural Networks for 3D Shape Recognition.”2015IEEE International Conference on Computer Vision(ICCV)(2015):945-953.)
PointNet对应Qi,C.等人提出的方法(Qi,C.,Hao Su,KaichunMo,et al.“PointNet:Deep Learning on Point Sets for 3D Classification andSegmentation.”2017IEEE Conference on Computer Vision and Pattern Recognition(CVPR)(2017):77-85.)
PointNet++对应Qi,C等人提出的方法(Qi,C.,L.Yi,Hao Su,et al.“PointNet++:Deep Hierarchical Feature Learning on Point Sets in a Metric Space.”NIPS(2017).)
DGCNN对应Wang,Yue等人提出的方法(Wang,Yue,Yongbin Sun,Z.Liu,et al.“Dynamic Graph CNN for Learning on Point Clouds.”ACM Transactions on Graphics(TOG)38(2019):1-12.)
FoldingNet对应Yang,Y.等人提出的方法(Yang,Y.,Chen Feng,Y.Shen,et al.“FoldingNet:Point Cloud Auto-Encoder via Deep Grid Deformation.”2018IEEE/CVFConference on Computer Vision and Pattern Recognition(2018):206-215.)
PointGLR对应Rao,Yongming等人提出的方法(Rao,Yongming,Jiwen Lu andJ.Zhou.“Global-Local Bidirectional Reasoning for Unsupervised RepresentationLearning of 3D Point Clouds.”2020IEEE/CVF Conference on Computer Vision andPattern Recognition(CVPR)(2020):5375-5384.)
MDC对应Song,Mofei等人提出的方法(Song,Mofei,Y.Liu and Xiao Fan Liu.“Semi-Supervised 3D Shape Recognition via Multimodal Deep Co-training.”Computer Graphics Forum 39(2020):n.pag.)
本发明仅需要点云这一种数据表示。为减少数据标注的成本,仅使用10%的有标签数据。为避免模型在有限的有标签数据上过拟合,本发明提出一个一致性约束分支来提高模型的泛化能力。此外,还为无标签数据生成伪标签来扩充现有的有标签数据。在一致性约束和伪标签的共同作用下,更好利用无标签点云数据,有效减少分类模型对有标签数据的需求。