论文阅读笔记 非自回归模型综述

题目 A Survey on Non-Autoregressive Generation for Neural Machine Translation and Beyond
论文链接 https://arxiv.org/pdf/2204.09269.pdf
作者列表 Yisheng Xiao*, Lijun Wu*, Junliang Guo, Juntao Li, Min Zhang, Tao Qin, Senior Member, IEEE and Tie-Yan Liu, Fellow, IEEE
作者单位 微软亚研院,苏州大学
文章类型 长文

简介

非自回归(NAR)生成首先用于加速翻译,称为NAT模型,缓解AT模型O(n)复杂度的问题,但这是以精度为代价的。

文章outline:

总览

优化目标

AT模型的优化目标:

(完全的)NAT模型的优化目标:

区别有二:第t个token的解码不取决于y<t的序列,且在推理阶段,句子长度T是需要预测的。

半自回归SAT模型的优化目标:

其中k是在一个时间步上并行生成的token数量。当k=1和k=T时,退化为AT和NAT模型。

基于迭代的NAT模型优化目标:

每一步的解码以上一步的生成句子为基础,继续优化。

模型架构

编码器基本没区别,主要是解码器。

NAT模型在训练时不再需要掩码矩阵。当然SAT模型仍然需要掩盖后续组的token。

主要挑战

NAT模型最关键的问题是,编码器的输入只能依赖源端的信息,作者称之为:“fail to capture the target dependency of target tokens”。

这会导致以下问题:

  1. 每一个target token都是独立预测的,模型可能选择不同合理翻译的错误组合,导致错误。
  2. 过度翻译和翻译不足,就是source token某些被翻译了多次,某些没被翻译到。

改进方法

  1. 训练数据改进
  2. 模型层面的改进
  3. 训练目标改进
  4. 解码策略的改进
  5. 借助预训练模型

数据改进

知识蒸馏(KD)

先用AT模型对训练集的句子解码得到目标句子,可以用贪婪、束搜索等方法,然后得到蒸馏后的数据集,用传统的负对数似然损失对NAT训练。

但是也发现不是更好的教师模型能教出更好的学生模型,他们的能力应该保持一致。知识蒸馏的优点是减少预测目标的依赖性、减少词汇多样性和token重排程度,有利于NAT学习source和target的对齐。

缺点是低频词的翻译准确性会降低,也有一些处理方法。

数据学习策略

主要是课程式学习,从难到易的学习。例如逐步将解码器的输入从AT模式切换到NAT模式。例如在训练的时候,对于decoder输入,第一步先用encoder输入,之后计算decoder输出和golden truth的距离,在golden truth中随机采样(采样数正比于距离)替代上一次的decoder输入,这样可以根据模型的能力逐步调整学习片段。另外还有从单词、短语、句子级别的课程式学习等等。

模型改进

基于迭代

插入和删除

基于平衡二叉树的生成顺序,就是同时预测位置和token,再预测子树的单词,就是insertion-transformer。也有基于删除的操作。

掩码和预测

在推理时,每次迭代时低预测概率的token会被mask以供下一次迭代。这种mask的选择也有改进工作:如源句子相应的oken也被联合掩盖;使用自回归的解码器(和原来的解码器共享参数)来预测需要mask的token;使用更高级的策略选择mask等。

基于潜变量

引入了潜变量Z,似然变为:

也就是先生成潜变量,再生成句子。其实就是decoder的输入如何优化。

文章简单介绍了很多方法,没办法一篇篇点进去看,下面两种是有图的:

第一种就是利用对齐信息,decoder的输入不再是简单的encoder输入,而是预测encoder输入中哪些单词如何与target对齐,再以此为decoder输入。

第二种就是重排,这个图有错误的,Reorder的输出是伪翻译,即source语言,但是顺序是以target语言为基础。

其他

  1. 使用短语表查找和嵌入映射,作为decoder的输入,使得其更接近目标句子
  2. 对中间状态进行监督。例如使用额外的注意力块来增强decoder层之间、以及生成句子的邻居token之间的关系。
  3. 对解码器的输出改进,例如惩罚相似的hidden state等,没有点进文章细看。

训练目标

最基础的是交叉熵

问题就是正确翻译的顺序可能有多种,而NAT预测的token没有相互依赖,交叉熵会惩罚那些正确翻译但顺序和golden truth不一致的翻译。

所以改成基于N-gram的损失函数。

另外CTC和基于顺序的损失函数就似懂非懂,干看综述里是很不详细的,要点进去细看要花很久,所以没有看。

解码

长度预测

NAT必须事先预测句子的长度,有以下方法:1. 建模长度预测为分类任务;2. 建模为线性回归;3. 引入[LENGTH]标记直接得到长度预测;4. 基于CTC的建模

改进的方法有:在预测长度后,选择一个范围而不是只使用这个长度,用这个范围预测并采用最佳翻译;或者直接预测多个可能的长度。

解码策略

半自回归解码

基于语法标签:生成语法参考树,一个步骤解码一组有语法关系的token

恢复机制:分段生成,每个段中的token是自回归生成的,基于当前所有已生成的token(包括其他段),引入DEL标记以删除重复的段。每个段起始都是BOS,直至遇到EOS或DEL停止。

Aggressive decoding:先用NAT生成草稿,然后用AT模型验证(验证时使用一些准则,如:在AT模型预测的top1位置,或者在AT模型预测的topk内且对数似然的差值在一定范围内),找到第一个不符合AT模型验证的token,把后面去掉,重新让NAT预测后面的句子。反复迭代。

迭代解码

掩码预测(上文已经介绍,有easy-first等)、插入和删除、重写机制(带修改器和定位器的机制,定位器用于确定哪个单词要保留/修改,需要修改的mask掉让修改器预测;还有一个应用波束搜索的,看不太懂)

混合解码

例如Diformer可以在推断时(每个时间步)动态选择L2R、R2L或迭代式NAR解码。

利用预训练模型

  1. 使用AT监督训练:这些方法大致是要求训练时NAT隐藏层的状态也向AT模型靠拢。
  2. 上文提到的课程式训练
  3. 联合AT模型一起训练:1. 用NAT模型的输出迭代更新AT,用AT的新输出训练NAT模型;2. 共享encoder,然后一起训练,一起计算损失。

NMT外的任务

我只看了摘要的任务,和NMT相比优势是目标句子的信息已经显式或隐式地在长文本输入中了。只有一篇论文:https://arxiv.org/abs/2205.14521,感觉有点复杂还没细看。

发展趋势

可以看到完全NAT已经越来越好了,快赶上迭代了。


论文阅读笔记 非自回归模型综述
https://bebr2.com/2022/11/16/论文阅读笔记 非自回归模型综述/
作者
BeBr2
发布于
2022年11月16日
许可协议