跨样本联邦分类建模方法及装置、存储介质、电子设备

文档序号:8211 发布日期:2021-09-17 浏览:32次 英文

跨样本联邦分类建模方法及装置、存储介质、电子设备

技术领域

本公开涉及联邦学习

技术领域

,尤其涉及一种基于神经网络和知识蒸馏的跨样本联邦建模方法及装置、存储介质、电子设备。

背景技术

随着深度学习研究的深入以及计算机设备的发展,人工神经网络被广泛应用于计算机人工智能领域。而为了保证训练出的人工神经网络具备良好的性能,通常需要大量数据投入训练。

但是,在某些场景下训练数据散落在不同的组织或机构中,出于数据隐私考虑,无法通过数据共享的方式满足人工神经网络的训练需求。即使能够通过共享的数据训练人工神经网络,那么对训练数据的通信和传输复杂度等要求也难以满足,更无法保障人工神经网络的训练效果。

鉴于此,本发明提出了一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法及装置。

发明内容

本公开的目的在于提供一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法、装置、计算机可读存储介质及电子设备,有利于在确保各个联邦参与方数据隐私的情况下,即通过共享的数据训练人工神经网络,在满足训练数据的通信和传输复杂度要求的基础上,同时保障人工神经网络的训练效果。

本公开的其他特性和优点将通过下面的详细描述变得显然,或部分地通过本公开的实践而习得。

根据本发明实施例的第一个方面,提供一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法,所述方法包括:

获取待联邦分类建模任务制定的标签标准信息,并根据所述标签标准信息以及本地训练数据对所述联邦分类建模任务对本地神经网络模型进行结构自定义和参数初始化处理;

对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,并对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量;

获取与所述联邦分类建模任务对应的联邦建模参数,并根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方;

接收所述协调方根据所述软标签向量返回的联邦标签向量,并根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。

在本发明的一种示例性实施例中,

所述对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,包括:

获取用于训练所述联邦分类建模任务的原始训练数据,并对所述原始训练数据进行标签对齐处理和数据过滤处理,得到本地目标训练数据;

利用所述本地目标训练数据对所述本地神经网络模型进行训练,得到所述本地目标训练数据的预测标签向量。

在本发明的一种示例性实施例中,所述利用所述本地目标训练数据对所述本地神经网络模型进行训练得到所述本地目标训练数据的预测标签向量,包括:

对所述本地目标训练数据进行数据划分得到本地训练数据集,并对所述本地训练数据集进行数据划分得到多组待训练数据;

利用多组所述待训练数据对所述本地神经网络模型进行迭代训练,得到所述本地目标训练数据的预测标签向量。

在本发明的一种示例性实施例中,

所述对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量,包括:

获取与知识蒸馏相关的温度参数,并对所述本地训练数据的预测标签向量和所述温度参数进行知识蒸馏计算,得到所述本地训练数据的预测标签的蒸馏向量;

对同类别的所述本地训练数据的预测标签的蒸馏向量进行平均值计算得到本地每个类别的软标签向量。

在本发明的一种示例性实施例中,

所述联邦建模参数包括联邦训练轮数,所述根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型,包括:

获取与所述本地训练数据对应的标签数据,并对所述预测标签向量和所述标签数据进行损失计算,得到第一损失值;

对所述预测标签向量和所述联邦标签向量进行损失计算,得到第二损失值,并根据所述第一损失值和所述第二损失值对所述本地神经网络模型进行更新;

对更新后的所述本地神经网络模型继续进行训练,直至所述神经网络模型的训练次数达到所述联邦训练轮数时,得到训练好的联邦分类建模模型。

在本发明的一种示例性实施例中,

所述联邦建模参数包括通信频率条件;

所述根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方,包括:

当对所述本地神经网络模型的训练过程满足所述通信频率条件时,将所述本地每个类别的软标签向量发送至协调方。

在本发明的一种示例性实施例中,

所述根据所述标签标准信息以及本地训练数据对所述本地神经网络模型进行结构自定义处理和参数初始化处理,包括:

确定所述本地神经网络模型的标准结构信息,并按照所述标准结构信息对所述本地神经网络模型进行结构自定义,得到所述本地神经网络模型的网络结构信息;

对所述本地神经网络模型进行参数初始化处理。

根据本发明实施例的第二个方面,提供一种基于神经网络和知识蒸馏的联邦分类建模装置,所述装置包括:

模型定义模块,被配置为获取待联邦分类建模任务制定的标签标准信息,并根据所述标签标准信息以及本地训练数据对所述联邦分类建模任务对本地神经网络模型进行结构自定义和参数初始化处理;

模型训练模块,被配置为对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,并对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量;

向量发送模块,被配置为获取与所述联邦分类建模任务对应的联邦建模参数,并根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方;

训练完成模块,被配置为接收所述协调方根据所述软标签向量返回的联邦标签向量,并根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。

根据本发明实施例的第三个方面,提供一种电子设备,包括:处理器和存储器;其中,存储器上存储有计算机可读指令,所述计算机可读指令被所述处理器执行时实现上述任意示例性实施例中的基于神经网络和知识蒸馏的跨样本联邦分类建模方法。

根据本发明实施例的第四个方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任意示例性实施例中的基于神经网络和知识蒸馏的跨样本联邦分类建模方法。

由上述技术方案可知,本公开示例性实施例中的一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法、装置、计算机存储介质及电子设备至少具备以下优点和积极效果:

在本公开的示例性实施例提供的方法及装置中,对本地神经网络模型进行结构自定义和参数初始化处理,能够有效防止本地神经网络模型过拟合或者是欠拟合的情况发生。进一步的,本地神经网络模型在本地进行训练的情况下,对预测标签向量进行知识蒸馏处理,学习了其他神经网络模型的训练数据量产生的分类能力,在极大降低了计算资源的基础上,训练效果更优。除此之外,在训练过程中,只需与协调方传输软标签向量和联邦标签向量,因此还能降低通信成本和传输成本。即通过共享的数据训练人工神经网络,在满足训练数据的通信和传输复杂度要求的基础上,同时保障人工神经网络的训练效果。

应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。

附图说明

此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。

图1示意性示出本公开示例性实施例中一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法的流程示意图;

图2示意性示出本公开示例性实施例中神经网络模型的节点示意图;

图3示意性示出本公开示例性实施例中复杂结构的神经网络模型的节点示意图;

图4示意性示出本公开示例性实施例中卷积神经网络的结构示意图;

图5示意性示出本公开示例性实施例中联邦分类建模任务的参与方进行结构自定义和参数初始化处理的方法的流程示意图;

图6示意性示出本公开示例性实施例中对本地神经网络模型进行训练的方法的流程示意图;

图7示意性示出本公开示例性实施例中利用目标训练数据对本地神经网络模型进行训练的方法的流程示意图;

图8示意性示出本公开示例性实施例中知识蒸馏处理的方法的流程示意图;

图9示意性示出本公开示例性实施例中每个联邦建模任务的参与方对本地神经网络模型继续训练的方法的流程示意图;

图10示意性示出本公开示例性实施例中应用场景下基于神经网络和知识蒸馏的跨样本联邦分类建模方法的流程示意图;

图11示意性示出本公开示例性实施例中应用场景下参与方A的网络结构的结构示意图;

图12示意性示出本公开示例性实施例中应用场景下参与方B的网络结构的结构示意图;

图13示意性示出本公开示例性实施例中应用场景下参与方C的网络结构的结构示意图;

图14示意性示出本公开示例性实施例中一种基于神经网络和知识蒸馏的跨样本联邦分类建模装置的结构示意图;

图15示意性示出本公开示例性实施例中一种用于实现基于神经网络和知识蒸馏的跨样本联邦分类建模的电子设备的结构示意图;

图16示意性示出本公开示例性实施例中一种用于实现基于神经网络和知识蒸馏的跨样本联邦分类建模的计算机可读存储介质的结构示意图。

具体实施方式

现在将参考附图更全面地描述示例实施方式。然而,示例实施方式能够以多种形式实施,且不应被理解为限于在此阐述的范例;相反,提供这些实施方式使得本公开将更加全面和完整,并将示例实施方式的构思全面地传达给本领域的技术人员。所描述的特征、结构或特性可以以任何合适的方式结合在一个或更多实施方式中。在下面的描述中,提供许多具体细节从而给出对本公开的实施方式的充分理解。然而,本领域技术人员将意识到,可以实践本公开的技术方案而省略所述特定细节中的一个或更多,或者可以采用其它的方法、组元、装置、步骤等。在其它情况下,不详细示出或描述公知技术方案以避免喧宾夺主而使得本公开的各方面变得模糊。

本说明书中使用用语“一个”、“一”、“该”和“所述”用以表示存在一个或多个要素/组成部分/等;用语“包括”和“具有”用以表示开放式的包括在内的意思并且是指除了列出的要素/组成部分/等之外还可存在另外的要素/组成部分/等;用语“第一”和“第二”等仅作为标记使用,不是对其对象的数量限制。

此外,附图仅为本公开的示意性图解,并非一定是按比例绘制。图中相同的附图标记表示相同或类似的部分,因而将省略对它们的重复描述。附图中所示的一些方框图是功能实体,不一定必须与物理或逻辑上独立的实体相对应。

针对相关技术中存在的问题,本公开提出了一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法,通过第一终端设备提供图形用户界面,图形用户界面显示至少部分的游戏场景。图1示出了一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法的流程图,如图1所示,基于神经网络和知识蒸馏的跨样本联邦分类建模方法至少包括以下步骤:

步骤S110.获取待联邦分类建模任务制定的标签标准信息,并根据标签标准信息以及本地训练数据对联邦分类建模任务的本地神经网络模型进行结构自定义和参数初始化处理。

步骤S120.对本地神经网络模型进行训练得到本地训练数据的预测标签向量,并对预测标签向量进行知识蒸馏处理,得到本地每个类别的软标签向量。

步骤S130.获取与联邦分类建模任务对应的联邦建模参数,并根据联邦建模参数将本地每个类别的软标签向量发送至协调方。

步骤S140.接收协调方根据软标签向量返回的联邦标签向量,并根据联邦标签向量和本地训练数据对本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。

重复迭代步骤S120~S140,以完成训练得到联邦分类建模模型。

在本公开的示例性实施例中,参与方首先可以对本方的神经网络模型进行结构自定义和参数初始化处理,能够有效防止模型的过拟合和欠拟合情况发生。进而,参与方与协调方的数据传输和通信仅包括软标签向量和联邦标签向量,大大降低了通信和传输成本,并且,该通信内容也无需额外的加密和解密计算,加速了神经网络模型的训练速率和通信速率,也保障了参与方的数据安全。除此之外,通信内容无需集中化存储,从而降低对资源存储的要求。然后,根据联邦标签向量对神经网络模型继续进行训练,参与方将其他参与方作为教师网络,当本地所拥有的训练数据量较少或网络结构较简单时,可以学习到教师网络的分类能力,比仅仅使用本方数据独立训练得到的神经网络模型的性能更优。最后,参与方也只需在本地进行训练,降低了计算资源量。

下面对神经网络模型的训练方法的各个步骤进行详细说明。

在步骤S110中,获取待联邦分类建模任务制定的标签标准信息,并根据标签标准信息以及本地训练数据对联邦分类建模任务的本地神经网络模型进行结构自定义和参数初始化处理。

在本公开的示例性实施例中,人工神经网络(Artificial Neural Netwo rk),可简称为神经网络,是模仿生物神经网络的结构和功能的数学或计算模型。神经网络由节点联接而成,其中单个节点模仿生物神经元,接收单个或多个输入,产生一个输出。

图2示出了神经网络模型的节点示意图,如图2所示,该神经网络模型是由3个输入和1个输出组成的节点构成。

进一步的,多个节点按照不同的联接方式可以组成不同结构的复杂神经网络。图3示出了复杂结构的神经网络模型的节点示意图,如图3所示,该复杂神经网络模型由4个输入层、2个中间层和1个输出层共10个节点组成。

随着深度学习研究的深入以及计算设备的发展,卷积神经网络(Co nvolutionNeural Network)被广泛应用于计算机视觉等领域。卷积神经网络是既包含卷积计算,又具有深度结构的神经网络。

典型的卷积神经网络包含卷积层、池化层和全连接层。其中,卷积层通常负责提取图像中的局部特征,池化层用来降低参数量级,全连接层类似于神经网络层,输出想要的结果。

图4示出了卷积神经网络的结构示意图,如图4所示,该神经网络模型包括1个卷积层、1个池化层和2个全连接层。

针对联邦分类建模任务制定的标签标准,可以是由参与联邦分类建模任务的发起方制定的。在联邦分类建模任务中,参与方是所有请求加入本次联邦分类建模任务的数据提供方;发起方是参与方中的一个,并且,一次联邦分类建模任务中仅有一个发起方;协调方是在联邦分类建模任务发起之前由所有参与方协商确认的,可以是参与方中的一个,也可以是参与方之外的机构或者组织担任的,并且一次联邦分类建模任务也只有一个协调方。

因此,该标签标准信息是由发起方制定的标签标准的信息。该标签标准信息中定义了联邦分类建模任务中的分类类别以及各类别的统一标签。并且,发起方在联邦分类建模任务中可以向其他参与方进行联邦建模请求,并向同意加入联邦分类建模任务的参与方发送该标签标准信息,以要求所有同意参与联邦分类建模任务的参与方均根据发起方制定的标签标准进行标注和训练。

在可选的实施例中,图5示出了联邦分类建模的参与方进行结构自定义和参数初始化处理的方法的流程示意图,如图5所示,该方法至少包括以下步骤:在步骤S510中,确定神经网络模型的标准结构信息,并按照标准结构信息对本地神经网络模型进行结构自定义,得到本地神经网络模型的网络结构信息。每个参与方可以按照本地的数据的数量和标签标准信息对本地神经网络模型进行结构自定义。

该本地神经网络结构自定义是按照标准结构信息,亦即包含必要的神经网络模型结构的情况下,对各个层的节点和层数等自行进行规定,但是,在规定过程中必须遵循最后一层的输出层的节点个数与标签标准信息中能够定义的分类类别个数相等。

当参与方本地的数据较多,且标签标准信息中定义的分类类别较多时,可以定义相对复杂的神经网络模型,以便获取精度较高的神经网络模型;当参与方本地的数据较少时,可以定义简单的神经网络模型,防止模型欠拟合的情况发生。

举例而言,结构自定义得到的神经网络模型的网络结构信息可以是,该神经网络模型包括2个卷积层、2个池化层和2个全连接层。

在步骤S520中,对本地神经网络模型进行参数初始化处理。

在得到网络结构信息之后,还可以对网络结构进行参数初始化处理。该参数初始化处理是对本地神经网络模型的各个层的权重进行初始化设置,可以由每个参与方根据神经网络模型自行选择。

在本示例性实施例中,通过结构自定义处理和参数初始化处理可以使得参与方根据自身情况进行自主设置,防止模型欠拟合的情况发生,也能够设置出精度较高的神经网络模型,保证了后续的训练效果。

在步骤S120中,对本地神经网络模型进行训练得到本地训练数据的预测标签向量,并对预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量。

在本公开的示例性实施例中,每个联邦分类建模任务的参与方的本地神经网络模型的训练主要是通过计算网络输出与输入数据的真实标签之间的误差,并将该误差反向传播来迭代更新网络当中的参数,从而使得网络输出向着输入数据的真实标签尽可能地靠近。在预测过程中,向训练好的网络结构中输入无标签的数据时,神经网络可以输出数据的标签预测值。

在可选的实施例中,图6示出了对本地神经网络模型进行训练的方法的流程示意图,如图6所示,该方法至少包括以下步骤:在步骤S610中,获取用于训练联邦分类建模任务的原始训练数据,并对原始训练数据进行标签对齐处理和数据过滤处理,得到本地目标训练数据。

该原始训练数据是储存在各个参与方的本地,且符合此次联邦分类建模任务的数据。并且,为了避免对无关数据的处理,可以对原始训练数据进行标签对齐处理和数据过滤处理。

具体的,参与方可以将在此次标签标准内的原始训练数据进行标注,标注方式按照此次的标签标准规定即可。

除此之外,原始数据还可能包括其他类别的数据,而其他类的数据是不参与此次的联邦分类建模任务的。当参与方提供的原始训练数据包含除标签标准之外的其他训练数据时,可以在该参与方的原始数据中删除该类数据得到目标训练数据,亦即过滤掉无意义的训练数据,当参与方的所有原始训练数据均不在标签标准定义的类别内时,表明该参与方为此次联邦分类建模任务的无效参与方,需要退出此次的联邦分类建模任务,以避免无效参与方参与本次的联邦分类建模任务。

在步骤S620中,利用本地目标训练数据对本地神经网络模型进行训练,得到本地训练数据的预测标签向量。

在可选的实施例中,图7示出了利用目标训练数据对本地神经网络模型进行训练的方法的流程示意图,如图7所示,该方法至少包括以下步骤:在步骤S710中,对本地目标训练数据进行数据划分得到本地训练数据集,并对本地训练数据集进行数据划分得到多组待训练数据。

每个参与方可以对拥有的目标训练数据进行数据划分,得到训练数据集和验证数据集两部分。举例而言,可以将目标训练数据按照4:1的比例划分成训练数据集和验证数据集。

进一步的,还可以将训练数据集进行数据划分得到多组待训练数据。具体的,将整个训练样本分成若干个Batch(批/一批样本),该Batch的Batch Size(批大小,每批样本的大小)可以是由参与方自行设置的。

在步骤S720中,利用多组待训练数据对本地神经网络模型进行迭代训练,得到本地目标训练数据的预测标签向量。

在得到多组待训练数据之后,可以将待训练数据输入至神经网络模型中,得到对应数据的预测标签向量。

在本示例性实施例中,通过目标训练数据可以实现对神经网络模型的训练,参与方的数据均在本地参与模型训练,无需向其他任何组织或机构透漏数据和标签等信息,保障了各个参与方的数据安全和隐私信息。

在得到本地训练数据的预测标签向量之后,还可以对预测标签向量进行知识蒸馏处理,得到本地的软标签向量。

在本示例性实施例中,图8示出了知识蒸馏处理的方法的流程示意图,如图8所示,该方法至少包括以下步骤:在步骤S810中,获取与知识蒸馏相关的温度参数,并对本地训练数据的预测标签向量和温度参数进行知识蒸馏计算,得到本地训练数据的预测标签的蒸馏向量。

该温度参数为后续进行知识蒸馏计算的参数,可以自行设定。

知识蒸馏是指利用教师网络(Teacher Network)输出的软标签(Soft Target)作为学生网络(Student Network)误差计算的一部分输入,来指导学生网络的训练,使得学生网络可用简单的小型网络结构即可实现与复杂教师网络一样的能力,达到与复杂的大型教师网络一样或相当的结果,从而实现知识迁移和复杂模型的压缩。

其中,软标签包括教师网络对输入的样本数据进行处理后得到的概率数据,实质上是教师网络根据样本数据得到的预测数据。在知识蒸馏中,该预测数据可以用于学生网络的训练,其性质接近于标签数据,但与真实标签存在不同,因此称为软标签。

在神经网络结构中,通常会使用Softmax层使输出结果用概率的形式表现出来,如神经网络的输出向量为[3,0,-3],将输出向量中每个元素代入Softmax计算公式:

qi=exp(Zi)/∑jexp(Zj) (1)

在利用公式(1)进行计算后得到的转换值约为[0.95,0.0476,0.0024],即表示该向量属于第一类、第二类、第三类的概率分别为0.95、0.0476、0.0024。

进一步的,在设定温度参数的情况下,可以按照公式(2)对输出向量和温度参数进行知识蒸馏计算:

qi=exp(Zi/T)/∑jexp(Zj/T) (2)

其中,T即为温度参数,qi为样本蒸馏向量。样本蒸馏向量具有更高的熵值,能够提供更多的信息和更小的梯度方差,因此比原始的教师网络更容易训练,且使用的学习效率更高。

由于样本蒸馏向量的取值在0-1之间,因此,当T的值越大时,样本蒸馏向量的取值在0-1之间的分布更加缓和,亦即样本蒸馏向量越“软”,也表示越能发挥教师网络对学生网络的指导作用。这是因为相当于在迁移学习过程中添加了扰动,从而使得学生网络在借鉴学习的时候更有效,且泛化能力更强,是一种抑制过拟合的策略。

例如当T=3时,得到转换后的样本蒸馏向量约为[0.663,0.246,0.091],是一个比[0.95,0.0476,0.0024]分布更“软”的概率。因此在学生网络训练过程中,可以让教师网络使用一个较高的温度T使其输出一个较“软”的分布,让学生模型的输出近似教师网络,从而将教师网络中的知识提取出来,因此称为知识蒸馏。

在步骤S820中,对同类别的本地训练数据的预测标签的蒸馏向量进行平均值计算,得到本地每个类别的软标签向量。

在得到样本蒸馏向量之后,可以将真实标签为同一类的目标训练数据的样本预测标签的蒸馏向量进行平均值计算得到该类目标训练数据的软标签向量。该平均值计算可以是一般的求值计算,也可以是加权平均计算,还可以按照实际情况进行计算,本示例性实施例对此不做特殊限定。

在本示例性实施例中,利用知识蒸馏计算的方式将每个参与方看作其他参与方的教师网络,以在后续训练过程中指导其他参与方的模型训练,即使参与方的训练数据较少或网络结构较简单,也能够学习到其他参与方的教师网络的分类能力,比仅仅使用本方数据训练模型的性能更优。并且,在参与方拥有的训练数据类别不全,如缺少某一类数据的情况下,也能使该参与方的模型学习到该缺失类别的分类能力,补全了模型训练的能力训练,使得训练出的网络模型完整且准确,效果极佳。

在步骤S130中,获取与联邦分类建模任务对应的联邦建模参数,并根据联邦建模参数将本地每个类别的软标签向量发送至协调方。

在本公开的示例性实施例中,知识联邦旨在确保各个参与方数据在不离开本地的情况下,各参与方交换数据中的“知识”,从而建立一个充分利用各个参与方数据的模型,达到“数据可用不可见,知识共创可共享”的目的。

其中,“知识”可以理解为参与方的服务器或终端与协调方的终端或服务器之间传递的信息。举例而言,参与方的终端或服务器中的“知识”可以是根据本地的样本数据提取或计算得到的。

根据各个参与方数据分布的特点,知识联邦可分为跨特征联邦、跨样本联邦以及复合型联邦。其中,跨样本联邦是指每个训练方的数据具有相同的特征,但各方的参与建模的样本是独立的,而且每个训练方都有与自己样本对应的标签数据。

跨样本联邦建模的目的是为了在数据不出本地的情况下充分利用所有训练方的样本和标签数据,获取一个比仅仅使用本地数据训练的模型效果更优的联邦模型。

为实现各个参与方之间的知识联邦,发起方可以确定与本次知识联邦相关的参数,亦即联邦建模参数。该联邦建模参数可以包括联邦训练轮数和通信频率条件。

其中,联邦训练轮数即为Epoch(轮数)。当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一次epoch。也就是说,所有训练样本在神经网络中都进行了一次正向传播和一次反向传播。再通俗一点,一个Epoch就是将所有训练样本训练一次的过程。

通信频率条件规定了知识蒸馏处理和各参与方之间进行通讯的频率和条件。举例而言,通信频率条件可以是每个Epoch进行一次知识蒸馏处理和通信。

在可选的实施例中,联邦建模参数包括通信频率条件。当对本地神经网络模型的训练过程满足通信频率条件时,将本地每个类别的软标签向量发送至协调方。

各参与方将本地的软标签向量发送给协调方再进行通信,因此,只有当对神经网络模型的训练过程满足通信频率条件时,才可以发送本地软标签向量。

协调方根据通信频率条件接收到各个参与方发送来的软标签向量后,可以对各个参与方的软标签向量进行计算得到每个参与方的联邦标签向量。每个参与方的联邦标签向量为除自身外的其他参与方提供的对应类别软标签向量的均值。

进一步的,协调方将各个参与方的联邦标签向量再发送给对应的参与方。

在步骤S140中,接收协调方根据软标签向量返回的联邦标签向量,并根据联邦标签向量和本地训练数据对本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。

在本公开的示例性实施例中,当协调方计算出每个参与方对应的联邦标签向量之后,可以将该联邦标签向量返回给参与方,因此,参与方可以接收到对应的联邦标签向量进行后续训练。

在可选的实施例中,联邦建模参数包括联邦训练轮数。图9示出了每个联邦分类建模任务的参与方对本地神经网络模型继续训练的方法的流程示意图,如图9所示,该方法至少包括以下步骤:在步骤S910中,获取与本地训练数据对应的标签数据,并对预测标签向量和标签数据进行损失计算,得到第一损失值。

值得说明的是,训练一个Batch就是一次Iteration(迭代)。为完成一次迭代,要通过输出向量与真实的标签数据之间的误差来迭代神经网络模型中的参数,从而使得神经网络模型的输出向量向着输入数据尽可能的靠近,以得到一个训练好的神经网络模型。

因此,每个联邦分类建模参与方获取到本地目标训练数据对应的标签数据,该标签数据即为对目标训练数据的真实标签。

进一步的,可以按照公式(3)计算输出向量与标签数据的第一损失值:

公式(3)为交叉熵损失函数。交叉熵(Cross Entropy)是香农信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。在信息论中,交叉熵是表示两个概率分布p,q的差异性信息,其中p表示真实分布,q表示非真实分布,在相同的一组事件中,其中,用非真实分布q来表示某个事件发生所需要的平均比特数。

交叉熵可在机器学习中作为损失函数,p代表真实标记的分布,q则代表训练后的模型的预测标记分布,交叉熵损失函数可以衡量p与q的相似性。交叉熵作为损失函数还有一个好处是使用sigmoid函数在梯度下降时,可以避免均方误差损失函数学习速率下降的问题,这是因为学习速率的下降速度是能够被输出的误差所控制的。

除此之外,对输出向量和标签数据进行损失计算得到第一损失值的方式也可以采用公式(3)实现,还可以采用其他计算方式,本示例性实施例对此不做特殊限定。

在步骤S920中,对预测标签向量和联邦标签向量进行损失计算,得到第二损失值,并根据第一损失值和第二损失值对本地神经网络模型进行更新。

具体的,可以对输出向量与联邦标签向量进行损失计算得到第二损失值,计算方式可以参照公式(3),也可以采用其他方式,本示例性实施例对此不做特殊限定。

进一步的,每个联邦分类建模参与方利用第一损失值和第二损失值两部分损失值作为本地神经网络模型更新的损失误差,对本地神经网络模型的参数进行更新。

在步骤S930中,对更新后的本地神经网络模型继续进行训练,直至神经网络模型的训练次数达到联邦训练轮数时,得到训练好的联邦分类建模模型。

在对神经网络模型更新结束之后,亦即完成了一次对神经网络模型的迭代训练。并且,可以对神经网络模型的训练次数加一。

进一步的,每个联邦分类建模参与方可以继续对更新后的本地神经网络模型进行训练。同样的,与本地神经网络模型的第一次训练一样,也要对继续训练的预测标签向量进行知识蒸馏处理,以及和协调方进行通信,更新本方的联邦标签向量,也要对该神经网络模型进行再一次的更新等处理。

在神经网络模型的训练次数不断统计的过程中,当神经网络模型的训练次数达到联邦训练轮数时,表明此次的联邦分类建模任务结束,亦即对神经网络模型的迭代训练结束,各个参与方可以得到训练好的神经网络模型,并退出训练。

在本示例性实施例中,通过损失值对神经网络模型的更新,可以迭代训练出训练好的神经网络模型,同时完成对多个神经网络模型的训练,训练方式灵活性更好,并且多个神经网络模型之间利用联邦标签向量可以相互影响,建模效果更佳。

下面结合一应用场景对本公开实施例中基于神经网络和知识蒸馏的跨样本联邦分类建模方法做出详细说明。

图10示出了应用场景下基于神经网络和知识蒸馏的跨样本联邦建模方法的流程示意图,如图10所示,在步骤S1010中,发起方发起任务并制定标签标准。

参与联邦分类建模任务的发起方可以制定联邦分类建模任务制定的标签标准。在联邦分类建模任务中,参与方是所有请求加入本次联邦分类建模任务的数据提供方;发起方是参与方中的一个,并且一次联邦分类建模任务中仅有一个发起方;协调方是在联邦分类建模任务发起之前由所有参与方协商确认的,可以是参与方中的一个,也可以是参与方之外的机构或者组织单人的,并且一次联邦分类建模任务也只有一个协调方。

举例而言,机构A有狗、猫、牛三种图像,但由于狗、牛类样本量为2000,猫的样本量20000,猫的样本数量大于狗和牛的样本数量,要训练一个三种动物的图像分类模型,首先需要建立标签标准,如表1所示:

表1

类别 标签
0
1
2

在该标签标准制定结束后可以得到标签标准信息。该标签标准信息中定义了此次联邦分类建模任务中的分类类别包括狗、猫和牛三种,并且狗的标签为0,猫的标签为1,牛的标签为2。

进一步的,机构A同时向机构B、C、D、E发送联邦建模请求,其中机构B、C、D同意加入联邦建模,机构E拒绝加入,故机构A将标签标准信息发送给同意加入联邦建模的机构B、C、D。

在步骤S1020中,参与方数据过滤与标签对齐。

机构A、B、C、D在本方拥有数据中收集标签标准中定义的狗、猫、牛类图像,其中每一方参与拥有的相关样本量(图像张数)统计如表2所示:

表2

除此之外,当参与方提供的原始训练数据包含除标签标准之外的其他训练数据时,可以在该参与方的原始数据中删除该类数据得到目标训练数据,亦即对原始训练数据进行数据过滤处理,以过滤掉无意义的训练数据,因此,其他类的训练数据是不参与此次的联邦分类建模任务的。并且,当参与方的所有原始训练数据均不在标签标准定义的类别内时,表明该参与方为此次联邦分类建模任务的无效参与方,需要退出此次的联邦分类建模任务。

由于机构D没有任何狗、猫、牛类图像数据,故退出参与方。但是,机构A、B、C一致同意将机构D作为协调方,且机构D答应作为协调方加入此次联邦分类建模任务中,因此,在此次训练过程中,A、B、C为参与方且A为任务发起方,D为协调方。之后,机构A、B根据标签标准将本方拥有的狗、猫、牛类图像分别标注为0、1、2,机构C根据标签标准将本方拥有的猫、牛类数据标注为1、2。

在步骤S1030中,参与方自定义神经网络结构与参数初始化方式。

参与方机构A、B、C分别根据本方的数据量和分类任务进行结构定义处理,以自定义神经网络结构。

图11示出了应用场景下参与方A的网络结构的结构示意图,如图11所示,参与方A的网络结构信息是该神经网络模型由1个卷积层、1个池化层和2个全连接层组成。

图12示出了应用场景下参与方B的网络结构的结构示意图,如图12所示,参与方B的网络结构信息是该神经网络模型由2个卷积层、2个池化层、3个全连接层组成。

图13示出了应用场景下参与方C的网络结构的结构示意图,如图13所示,参与方C的网络结构信息为该神经网络模型由2个卷积层、2个池化层、2个全连接层组成。

除此之外,图11、图12和图13所示的网络结构的输出层均有3个节点,输出的值分别表示图像样本是狗、猫和牛的概率。

在得到网络结构信息之后,还可以对网络结构信息进行参数初始化处理。该参数初始化处理是对神经网络模型的各个层进行权重设置,可以由每个参与方根据神经网络模型自行选择。

在步骤S1040中,参与方根据真实标签本地训练,蒸馏类别软标签。

参与方机构A、B和C分别根据本方拥有的每个类别的目标训练数据按4:1(当然,具体比例可以按照实际情况进行设定,此处不做限定)的比例划分成训练数据集和验证数据集,故A、B、C的训练集、验证集中每类样本数据如表3所示:

表3

参与方A、B、C将本方训练数据集按batch划分,如A、B、C方各自定义的一个训练batch的数据大小分别为32、256、128。

并且,发起方A确定联邦建模参数,如每一参与方的联邦训练轮数为10个epoch,软标签通信频率,即通信频率条件为每个epoch进行一次软标签蒸馏与通信。进一步的,发起方A将这些参数发送给参与方B、C和协调方D。

接着,参与方A、B、C各自根据本方拥有的训练数据集开始迭代训练本地的神经网络模型,以根据本地模型的输出向量与训练样本真实标签计算每次迭代的训练损失,将损失反向传播,更新本地模型的参数。

第一轮训练结束后,根据公式(2)对第一轮训练的输出向量和对应的温度参数进行知识蒸馏计算得到样本蒸馏向量。

然后,将标签为0的样本蒸馏向量求平均值得到狗类样本的软标签向量,将标签为1的样本蒸馏向量求平均值得到猫类样本的软标签向量,同样地,将标签为2的样本蒸馏向量求平均值,得到牛类样本的软标签向量。

举例而言,参与方A、B和C计算出的狗、猫、牛类的软标签向量分别为[0.6,0.3,0.1]、[0.4,0.5,0.1]、[0.2,0.2,0.6]。参与方A、B、C分别将每类样本的软标签向量发送给协调方D。

在步骤S1050中,判断所有参与方训练是否结束。

在每次训练结束之后,每个联邦分类建模参与方训练的结束情况进行判断。

在步骤S1060中,协调方联邦软标签计算。

当对参与方训练结束情况的判断为所有参与方未结束训练时,协调方可以计算各个参与方的联邦标签向量。

机构D收到参与方A、B、C第一轮发送来的每类样本的软标签向量后,分别计算参与方A、B、C的联邦标签向量,并将A、B、C联邦标签向量发送给对应的参与方A、B、C。

例如,其中参与方C的猫类联邦标签向量[0.35,0.55 0.1]为A、B方发送来的猫类标签向量[0.4,0.5,0.1]、[0.3,0.6,0.1]的平均值。

在步骤S1070中,参与方根据真实标签与联邦软标签进行本地训练;蒸馏类别软标签。

参与方A、B、C根据协调方D发送来的联邦标签向量以及本地数据开始下一轮训练。每个batch训练的损失值根据输出向量与真实标签之间的第一损失值以及输出向量与联邦标签向量之间的第二损失值两部分组成,更新神经网络模型的参数。

当下一轮数据训练完成后,重新蒸馏每类样本的软标签并将更新后的软标签发送给机构D。

重复步骤S1050-S1070三步,直到参与方均训练完10轮,即达到联邦训练轮数后训练结束,机构A、B、C均得到训练好的神经网络模型。

在该应用场景下的基于神经网络模型和知识蒸馏的跨样本联邦分类建模方法,在隐私保护方面,与集中式机器学习方法相比,参与方的数据均在本地参与神经网络模型的训练,无需向其他组织或机构透露本方的数据和标签信息,而且也只需向协调方传输每类软标签向量,且该软标签向量为一均值,不会泄露参与方训练数据的真实标签。总而言之,各个参与方以及参与方与协调方之间无需共享数据,保证了各个参与方的数据安全性。

在资源要求方面,与集中式机器学习方法相比,每个参与方在模型训练过程中只根据本地数据量进行训练,每一参与方的计算资源远远低于集中化训练所需的计算资源。另一方面,每个参与方只根据本方的训练数据和联邦标签向量进行模型的迭代更新,无需将所有参与方的训练数据进行集中化存储,存储资源要求也低于集中化的存储方式。进一步的,也无需将每一参与方的训练数据传输至其他组织或机构,各参与方与协调方通信的仅为每类的软标签向量,通信内容较小,大大降低了通信和传输成本。

在建模效果方面,各个参与方可以根据本方的训练数据量自行定义神经网络模型的复杂度,能够有效防止模型的过拟合和欠拟合情况发生。并且,每个参与方将其他参与方作为教师网络,当本地所拥有的训练数据量较少或网络结构较简单时,也可以学习到教师网络的分类能力,比仅仅使用本方数据训练神经网络模型的性能更优。而在参与方,对拥有的样本类别进行补全,例如缺少某一类样本时,也可以使该参与方学习到该缺失类别的分类能力,完整度更好。

在灵活性和机动性方面,各参与方不需要定义相同的神经网络结构,与其他联邦参数聚合方案相比,本方案更具有灵活性。并且,在联邦分类建模任务过程中,参与方可以随时申请加入或者退出本次联邦建模过程。

在训练速率和通信速率方面,与其他联邦参数聚合方案相比,该应用场景下每个参与方只需要在每轮训练结束时传输与接收每类联邦标签向量,且通信的内容不需要再进行额外的加密和解密计算,大大加速了神经网络模型的训练速率和通信速率。

通过本方案中共享的数据训练人工神经网络,能够在满足训练数据的通信和传输复杂度要求的基础上,同时保障人工神经网络的训练效果。

此外,在本公开的示例性实施例中,还提供一种基于神经网络和知识蒸馏的跨样本联邦分类建模装置。图14示出了基于神经网络和知识蒸馏的跨样本联邦分类建模装置的结构示意图,如图14所示,基于神经网络和知识蒸馏的跨样本联邦分类建模装置1400可以包括:模型定义模块1410、模型训练模块1420、向量发送模块1430和训练完成模块1440。

其中:

模型定义模块1410,被配置为获取待联邦分类建模任务制定的标签标准信息,并根据所述标签标准信息以及本地训练数据对所述联邦分类建模任务对本地神经网络模型进行结构自定义和参数初始化处理;

模型训练模块1420,被配置为对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,并对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量;

向量发送模块1430,被配置为获取与所述联邦分类建模任务对应的联邦建模参数,并根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方;

训练完成模块1440,被配置为接收所述协调方根据所述软标签向量返回的联邦标签向量,并根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。

在本发明的一种示例性实施例中,所述对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,包括:

获取用于训练所述联邦分类建模任务的原始训练数据,并对所述原始训练数据进行标签对齐处理和数据过滤处理,得到本地目标训练数据;

利用所述本地目标训练数据对所述本地神经网络模型进行训练,得到所述本地目标训练数据的预测标签向量。

在本发明的一种示例性实施例中,所述利用所述本地目标训练数据对所述本地神经网络模型进行训练得到所述本地目标训练数据的预测标签向量,包括:

对所述本地目标训练数据进行数据划分得到本地训练数据集,并对所述训练数据集进行数据划分得到多组待训练数据;

利用多组所述待训练数据对所述本地神经网络模型进行迭代训练,得到所述本地目标训练数据的预测标签向量。

在本发明的一种示例性实施例中,

所述对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量,包括:

获取与知识蒸馏相关的温度参数,并对所述本地训练数据的预测标签向量和所述温度参数进行知识蒸馏计算,得到所述本地训练数据的预测标签的蒸馏向量;

对同类别的所述本地训练数据的预测标签的蒸馏向量进行平均值计算得到本地每个类别的软标签向量。

在本发明的一种示例性实施例中,

所述联邦建模参数包括联邦训练轮数,所述根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型,包括:

获取与所述本地训练数据对应的标签数据,并对所述预测标签向量和所述标签数据进行损失计算,得到第一损失值;

对所述预测标签向量和所述联邦标签向量进行损失计算,得到第二损失值,并根据所述第一损失值和所述第二损失值对所述本地神经网络模型进行更新;

对更新后的所述本地神经网络模型继续进行训练,直至所述神经网络模型的训练次数达到所述联邦训练轮数时,得到训练好的联邦分类建模模型。

在本发明的一种示例性实施例中,

所述联邦建模参数包括通信频率条件;

所述根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方,包括:

当对所述本地神经网络模型的训练过程满足所述通信频率条件时,将所述本地每个类别的软标签向量发送至协调方。

在本发明的一种示例性实施例中,

所述根据所述标签标准信息以及本地训练数据对所述本地神经网络模型进行结构自定义处理和参数初始化处理,包括:

确定所述本地神经网络模型的标准结构信息,并按照所述标准结构信息对所述本地神经网络模型进行结构自定义,得到所述本地神经网络模型的网络结构信息;

对所述本地神经网络模型进行参数初始化处理。

上述一种基于神经网络和知识蒸馏的跨样本联邦分类建模装置1400的具体细节已经在对应的神经网络模型的训练方法中进行了详细的描述,因此此处不再赘述。

应当注意,尽管在上文详细描述中提及了一种基于神经网络和知识蒸馏的跨样本联邦分类建模装置1400的若干模块或者单元,但是这种划分并非强制性的。实际上,根据本公开的实施方式,上文描述的两个或更多模块或者单元的特征和功能可以在一个模块或者单元中具体化。反之,上文描述的一个模块或者单元的特征和功能可以进一步划分为由多个模块或者单元来具体化。

此外,在本公开的示例性实施例中,还提供了一种能够实现上述方法的电子设备。

下面参照图15来描述根据本发明的这种实施例的电子设备1500。图15显示的电子设备1500仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。

如图15所示,电子设备1500以通用计算设备的形式表现。电子设备1500的组件可以包括但不限于:上述至少一个处理单元1510、上述至少一个存储单元1520、连接不同系统组件(包括存储单元1520和处理单元1510)的总线1530、显示单元1540。

其中,所述存储单元存储有程序代码,所述程序代码可以被所述处理单元1510执行,使得所述处理单元1510执行本说明书上述“示例性方法”部分中描述的根据本发明各种示例性实施例的步骤。

存储单元1520可以包括易失性存储单元形式的可读介质,例如随机存取存储单元(RAM)1521和/或高速缓存存储单元1522,还可以进一步包括只读存储单元(ROM)1523。

存储单元1520还可以包括具有一组(至少一个)程序模块1525的程序/实用工具1524,这样的程序模块1525包括但不限于:操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。

总线1530可以为表示几类总线结构中的一种或多种,包括存储单元总线或者存储单元控制器、外围总线、图形加速端口、处理单元或者使用多种总线结构中的任意总线结构的局域总线。

电子设备1500也可以与一个或多个外部设备1700(例如键盘、指向设备、蓝牙设备等)通信,还可与一个或者多个使得用户能与该电子设备1500交互的设备通信,和/或与使得该电子设备1500能与一个或多个其它计算设备进行通信的任何设备(例如路由器、调制解调器等等)通信。这种通信可以通过输入/输出(I/O)接口1550进行。并且,电子设备1500还可以通过网络适配器1560与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图所示,网络适配器1540通过总线1530与电子设备1500的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备1500使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统等。

通过以上的实施例的描述,本领域的技术人员易于理解,这里描述的示例实施例可以通过软件实现,也可以通过软件结合必要的硬件的方式来实现。因此,根据本公开实施例的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中或网络上,包括若干指令以使得一台计算设备(可以是个人计算机、服务器、终端装置、或者网络设备等)执行根据本公开实施例的方法。

在本公开的示例性实施例中,还提供了一种计算机可读存储介质,其上存储有能够实现本说明书上述方法的程序产品。在一些可能的实施例中,本发明的各个方面还可以实现为一种程序产品的形式,其包括程序代码,当所述程序产品在终端设备上运行时,所述程序代码用于使所述终端设备执行本说明书上述“示例性方法”部分中描述的根据本发明各种示例性实施例的步骤。

参考图16所示,描述了根据本发明的实施例的用于实现上述方法的程序产品1600,其可以采用便携式紧凑盘只读存储器(CD-ROM)并包括程序代码,并可以在终端设备,例如个人电脑上运行。然而,本发明的程序产品不限于此,在本文件中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。

所述程序产品可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以为但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。

计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了可读程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。可读信号介质还可以是可读存储介质以外的任何可读介质,该可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。

可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于无线、有线、光缆、RF等等,或者上述的任意合适的组合。

可以以一种或多种程序设计语言的任意组合来编写用于执行本发明操作的程序代码,所述程序设计语言包括面向对象的程序设计语言—诸如Java、C++等,还包括常规的过程式程序设计语言—诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。在涉及远程计算设备的情形中,远程计算设备可以通过任意种类的网络,包括局域网(LAN)或广域网(WAN),连接到用户计算设备,或者,可以连接到外部计算设备(例如利用因特网服务提供商来通过因特网连接)。

本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本公开的其他实施例。本申请旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本公开的真正范围和精神由权利要求指出。

完整详细技术资料下载
上一篇:石墨接头机器人自动装卡簧、装栓机
下一篇:基于深度学习的非侵入负荷分解方法、系统、介质和设备

网友询问留言

已有0条留言

还没有人留言评论。精彩留言会获得点赞!

精彩留言,会给你点赞!

技术分类