对于一个新任务任务,有输入集合 $\left{\mathrm{x}_1, \mathrm{x}_2, \ldots, \mathrm{x}_N\right}$ ,和K个不同的prompt $\left{\left(r_x^{(1)}, r_y^{(1)}\right), \ldots,\left(r_x^{(K)}, r_y^{(K)}\right)\right}$ 。本文的目的是无监督学习函数 $f: \mathcal{X} \rightarrow \mathcal{Y}$ 。
无标签的输入在实践中经常出现,我们在文中考虑了两种这样的情况。
以前的方法使用一个额外的模块来扰乱每个例子,然后优化例子级的一致性,而我们建议优化提示级的一致性,这(1)在概念上是简单的,(2)可以缓解PLM的预测通常与同一任务的不同提示不一致的事实。我们建议将不同的提示对给定输入的预测正则化,使其相互接近,使用成对的蒸馏损失来使一个提示的预测更接近另一个提示的预测。
作者希望同一个数据的不同prompt,在infer的时候,能得到相同的结果。将这个想法作为约束, 就能在无监督数据上训练模型。损失函数为:损失函数定义为:
\[\begin{aligned} \mathcal{L}=-& \mathbb{E}_{\mathbf{x} \sim p_d(\mathbf{x})} \mathbb{E}_{r^{(i)}, r}(j) \sim p(r) \\ & \mathbb{E}_{\hat{\mathbf{y}} \sim \hat{q}\left(\mathbf{y} \mid \mathbf{x}, r^{(i)}\right)} \log p_\theta\left(r_y^{(j)}(\hat{\mathbf{y}}) \mid r_x^{(j)}(\mathbf{x})\right) \end{aligned}\]$p_d(\mathrm{x})$ 是输入数据的分布, $p(r)$ 是prompt的分布。
对于这个损失函数的理解:对于同一个数据,不同prompt都有一个对于标签对预测分布,该损失约束要求不同标签间的这个分布差异要尽量小。
在成对一致性损失中停止一方的梯度,也被证明有助于缓解所有输入导致相同预测的崩溃问题。
与传统的从教师模型提炼到学生模型的提炼(Hinton等人,2015),或以前的一致性训练,即一个教师提炼到几个学生(Clark等人,2018;Xie等人,2020a)不同,我们在一群提示中两两比较进行提炼,每个提示同时是教师和学生,因此我们把我们的方法称为swarm distillation。在我们的实施中,我们用k个随机抽样的提示对来近似期望值(r(i), r(j)),而不是列举所有的提示对来提高训练效率。
直接这样训练整个网络可能佘使得模型坍塌(collapse),即模型会进入对所有输入,所有 prompt都预测同一个标签的局部最优,来获得一个较高的一致性。
为了避免这个情况,作者采用LoRA模块,新增了一部分可训练参数,保持原网络fix。
具体来说,即在预训练好的transformer的每一层的FNN的weight $W$ 中,加入新的参数,将其变 为 $W+\alpha B A , W$ 不更新,只更新 $B$ 和 $A$ 。
停止策略
由于没有valid集,作者设计了一个训练停止条件,利用Fleiss Kappa指标来衡量预测的一致性。Fleiss'kappa表示,如果所有的提示都按照标签的边缘化分布进行预测,那么提示之间的一致程度就会超过预期的程度。 \(p_i=\frac{1}{K(K-1)} \sum_j n_{i j}\left(n_{i j}-1\right),\) $n_{i j}$ 是模型对第 $i$ 个数据的第 $j$ 标签的预测的prompt个数, $K$ 是prompt的个数。则 $p_i$ 反映了 对第 $i$ 个数据的预测的一致性, $p_i$ 越高,一致性越强。
\(\bar{P}_e=\sum_j q_j^2, \quad q_j=\frac{1}{N K} \sum_{i=1}^N n_{i j},\) $\bar{P}_e$ 反映了模型对所有数据中,标签分布的极性。若模型只预测一个标签, $\bar{P}_e$ 最大,此时即上述所说的坍塌。 Fleiss Kappa指标计算为: $\kappa=\frac{\bar{P}-\bar{P}_e}{1-\bar{P}_e}$ 作者在 $\kappa$ 开始单调递减是停止训练。
在数据集上表现SOTA
在本文中,我们探索了提示一致性正则化,使PLM成为更好的零样本学习者。我们的方法利用未标记的例子来达到零样本的收益。虽然我们以后适应的方式使用它来适应PLM与所提出的swarm distillation,但我们的正则化损失有可能与预训练阶段的预训练目标相结合,与多提示训练损失(Sanh等人,2022;Wei等人,2022)相结合,甚至与少样本学习设置的注释数据相结合。将swarm distillation与这些其他损失结合起来,可以很容易地绕过模型崩溃的问题,因为其他损失通常不鼓励崩溃的局部最优。 无监督swarm distillation法在序列生成任务上的潜在应用也值得在未来进行研究