3.【新神经网络架构】Transformer 论文精读

约 3090 字大约 10 分钟...

3.【新神经网络架构】Transformer 论文精读

原文链接:Attention Is All You Needopen in new window 源码链接:https://github.com/tensorflow/tensor2tensoropen in new window

0. 核心总结

核心是提出了一种继CNN、RNN之后的新型神经网络架构Transformer,这种架构完全基于注意力机制,可以学习到更底层特征,在机器翻译上取得了SOTA的效果,并且目前对图像、音频、视频等领域也影响深远。

1. 摘要

本文提出了一个全新的网络架构,称为Transformer,完全基于注意力机制,完全摒弃了卷积和循环神经网络。通过多个机器翻译实验表明,本文的Transformer模型在精度上更优,同时可以并行化计算和需要的训练时间更少。在WMT2014英德翻译任务上取得了28.4个BLEU,比包括集成模型在内的最好结果高了2个BLEU以上。

2. 引言

RNN、LSTM、GRU是此前最为先进的序列模型,但是由于其顺序计算的特性,无法并行化。注意力机制(Attention)的提出,使得可以在不同位置之间建立长距离的依赖关系,从而可以并行化计算,通常这类注意力机制都与循环网络结合使用。

在本文的工作中,提出了一种Transformer结构,完全依靠注意力机制来绘制输入和输出之间的全局依赖关系。Transformer支持更多的并行化计算,在8个P100 GPU上训练短短12个小时后就可以达到翻译质量的先进水平。

3. Transformer模型结构

Transformer模型的整体结构如下图所示,编码器和解码器都使用堆叠的自注意力(Self-attention)逐点全连接层(Point-wise fully connected layers),分别如图的左半部分和右半部分所示:

Transformer模型结构
Transformer模型结构

大多数先进的序列模型都有编码器-解码器结构,在Transformer中,编码器将输入序列 (x1,,xn)(x_1,\ldots,x_n) 映射为一个连续的表示 z=(z1,,zn)\pmb{z}=(z_1,\ldots,z_n)

给定 z\pmb{z} ,解码器逐个生成输出序列 (y1,,ym)(y_1,\ldots,y_m) ,一次一个元素,在每一次,模型都是自回归的,即在生成下一个元素 yiy_i 时,模型都是输入之前生成的元素 y1,,yi1y_1,\ldots,y_{i-1} 作为额外的输入。

3.1 编码器和解码器块细节

  • 编码器(Encoder):编码器由 N=6N=6 个相同的层堆叠而成,每层包含两个子层。第一层是一个多头自注意力机制(Multi-head self-attention),第二层是一个简单的逐点全连接前馈网络(Point-wise fully connected feed-forward network)。 每个子层都有一个残差连接(Residual Connection),然后进行层归一化(Layer Normalization),即每个子层的输出是 LayerNorm(x+Sublayer(x))\mathrm{LayerNorm}(x+\mathrm{Sublayer}(x)) 。为了便于进行残差连接,模型的所有子层和嵌入层输出的维度都是 dmodel=512d_{\text{model}}=512

  • 解码器(Decoder):解码器同样也由 N=6N=6 个相同的层堆叠而成。解码器的结构基本与编码器相同,包含编码器中的两个子层,但解码器额外有的第三个子层:掩码多头注意力机制(Masked Multi-head Attention)。这个掩码是为了在训练和验证时不会看到未来的信息,即在生成第 ii 个词时,只能看到前 i1i-1 个词,而不能看到第 ii 个词及其后面的词。

  • 层归一化(Layer Normalization):层归一化和批归一化(Batch Normalization)类似,但是批归一化是在一个batch中对每个样本的同一个特征进行归一化,而层归一化是对每个样本的不同特征进行归一化,如下图所示:

层归一化示意图
层归一化示意图

3.2 注意力机制细节

一个注意力函数可以将一个 查询(Query) 和一组 键(Key)值(Value) 对映射到一个输出,输出为值的加权和,权重由查询与对应的键的相似度决定。

3.2.1 缩放点积注意力(Scaled Dot-Product Attention)

本文中的注意力函数是缩放点积注意力(Scaled Dot-Product Attention),如下图所示:

缩放点注意力计算示意图
缩放点注意力计算示意图

其中,查询和键的维度为 dkd_k ,值的维度为 dvd_v 。查询和键的矩阵乘法得到的矩阵除以 dk\sqrt{d_k} ,然后进行softmax操作得到权重矩阵,最后与值的矩阵乘法得到输出矩阵,计算公式如下:

Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V

3.2.2 多头注意力(Multi-head Attention)

与单独使用注意力函数相比,使用多头注意力函数将查询、键和值线性投影到 hh 个不同的 dkd_kdkd_kdvd_v 个维度是有益的。我们还可以并行计算这 hh 个头,生成 dvd_v 维度的输出值,然后将这些值拼接起来再次进行线性投影,得到最终的输出值,如下图所示:

多头注意力示意图
多头注意力示意图

多头注意力允许模型在不同位置共同关注来自不同表示子空间的信息,在单个注意力头中,模型可能会忽略来自其他子空间的相关信息。

多头注意力机制的计算公式如下:

MultiHead(Q,K,V)=Concat(head1,,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV) \begin{aligned} \text{MultiHead}(Q,K,V) &= \text{Concat}(\text{head}_1,\ldots,\text{head}_h)W^O \\ \text{where head}_i &= \text{Attention}(QW_i^Q,KW_i^K,VW_i^V) \end{aligned}

其中,投影为参数矩阵,它们的维度为 WiQRdmodel×dkW_i^Q\in\mathbb{R}^{d_{\text{model}}\times d_k}WiKRdmodel×dkW_i^K\in\mathbb{R}^{d_{\text{model}}\times d_k}WiVRdmodel×dvW_i^V\in\mathbb{R}^{d_{\text{model}}\times d_v}WORhdv×dmodelW^O\in\mathbb{R}^{hd_v\times d_{\text{model}}}

在本文实验中,采用 h=8h=8 个头,dk=dv=dmodel/h=64d_k=d_v=d_{\text{model}}/h=64 。由于每个头的维度减小,总的计算成本与单头注意力机制相似。

3.3 逐点前馈网络细节

逐点前馈网络分别且相同地应用于输入地每个位置,这包含两个线性变换,中间有一个 ReLU 激活函数,计算公式如下:

FFN(x)=max(0,xW1+b1)W2+b2 \text{FFN}(x)=\max(0,xW_1+b_1)W_2+b_2

其中,输入输出的维度数均为 dmodel=512d_{\text{model}}=512 ,中间层的维度数为 dff=2048d_{ff}=2048

3.4 位置编码细节

不同于 RNN 和 CNN,Transformer 模型中没有显式的循环和卷积,因此模型需要一种方法来利用序列中单词的顺序信息。为了在序列中的单词中注入一些位置信息,本文采用了位置编码(Positional Encoding)。

位置编码与词嵌入具有相同的维度 dmodeld_{\text{model}},因此最终输出为两者相加。位置编码的值是根据位置的正弦和余弦函数计算得到的,如下式:

PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel) \begin{aligned} PE_{(pos,2i)} &= \sin(pos/10000^{2i/d_{\text{model}}}) \\ PE_{(pos,2i+1)} &= \cos(pos/10000^{2i/d_{\text{model}}}) \end{aligned}

其中,pospos 是位置,ii 是维度。也就是说,位置编码的每个维度对应一个正弦曲线,波长逐渐增加。本文还尝试了其他编码方式,但是发现这种方式的效果最好。

4. 为什么使用自注意力

本章将自注意力机制与 RNN 和 CNN 在三个指标上进行了比较。一是每层的计算复杂度;二是可并行的计算量,用所需的最小顺序操作数来衡量;三是最长路径内的依赖关系,用最大路径长度来衡量。

对比结果如下表所示,其中 nn 为序列长度,dd 为每个位置的表示维度,kk 为卷积核的大小,rr 为受限自注意力中的领域大小:

不同层类型的对比
不同层类型的对比

5. 训练细节

5.1 训练数据集和批大小

本文使用了标准的 WMT 2014 英语到德语的数据集,共包含 4.5M 句子对。句子被编码成约有 37000 tokens 的共享源-目标词汇表。

对于英语到法语,本文使用了更大的 WMT 2014 英语到法语的数据集,共包含 36M 句子对,拆分成 32000 tokens 词块词汇表。具有近似序列长度的数据被分配到同一批中,每个批包含约 25000 个源 tokens 和 25000 个目标 tokens。

5.2 训练硬件和训练计划

本文使用了 8 块 NVIDIA P100 GPU 进行训练。对于基础模型,每步大约需要0.4秒,共训练了10万步(12小时)。对于大模型,每步时长约为1秒,共训练了30万步(3.5天)。

5.3 优化器

本文使用 Adam 优化器,其中 β1=0.9\beta_1=0.9β2=0.98\beta_2=0.98ϵ=109\epsilon=10^{-9}。对于学习率,本文按以下公式进行变化:

lrate=dmodel0.5min(step_num0.5,step_numwarmup_steps1.5) lrate = d_{\text{model}}^{-0.5}\cdot\min(step\_num^{-0.5},step\_num\cdot warmup\_steps^{-1.5})

其中,step_numstep\_num 为当前训练步数,warmup_steps=4000warmup\_steps=4000

5.4 正则化

在训练过程中,本文使用三类正则化方法:

  • 残差Dropout:本文将Dropout引用到每个子层的输出,然后再进行残差连接和层归一化。Dropout的概率为 Pdrop=0.1P_{drop}=0.1

  • 标签平滑(Label Smoothing):本文在训练过程中使用了 ϵls=0.1\epsilon_{ls}=0.1 的标签平滑。标签平滑的作用是让模型对于训练数据中的噪声更加鲁棒。

6. 实验结果

6.1 机器翻译结果

下表总结了本文提出的Transformer架构在 WMT 2014 英语到德语和英语到法语数据集上与其它模型架构的结果:

Transformer与其它模型架构的对比结果
Transformer与其它模型架构的对比结果

6.2 模型参数分析

为了评估Transformer的不同参数的重要性,本文在 newstest2013 英德翻译数据集上进行了一系列实验。下表总结了实验结果:

Transformer参数分析的结果
Transformer参数分析的结果
  1. 在(A)中,改变了注意力头的数量和键值维度,结果表明,适量地增加头的数量可以提高性能,但太多的头也会降低性能。
  2. 在(B)中,可以发现降低键的维度会降低性能。
  3. 在(C)中,可以发现增加模型深度,即更大的模型可以提高性能。
  4. 在(D)中,可以发现使用正则化方法可以提高性能。
  5. 在(E)中,将正弦位置编码替换为可学习的位置嵌入,可以发现实际性能与基模型几乎相同。

7. 个人思考

本文的最大贡献就是提出了一种全新的神经网络架构Transformer,该架构影响深远,目前也已经广泛应用于图像、音频和视频等领域。

Transformer的核心注意力机制,顾名思义类似于人的注意力,查询、键和值分别可以认为是人的注意力上的自主性提示、非自主性提示和感官输入值。打个比方就是,当你在看一幅画时,你的主观意识会自主地注意画中的某个部分,比如一只猫,这就是查询;而当你看到画中的一只猫时,你的注意力会被非自主地吸引到猫的周围,比如猫的眼睛、耳朵、鼻子等,这就是键;最后,你会将这些感官输入值进行整合,形成对这幅画的理解,这就是值。

注意力机制乍一看和多层感知机MLP很像,无非就是权重和输入的加权和,但是注意力机制的权重是由输入决定的,而MLP的权重是固定的。这也是为什么注意力机制可以用于序列建模的原因,因为序列中的每个位置的输入都是不同的,所以注意力机制可以根据不同的输入来决定权重。

Transformer还有一个很大的特点是,Transformer具有很大的网络容量,需要学习很大的数据量才能学习到底层特征,更适合大规模数据集。这是由于Transformer结构缺少一些而类似CNN先天的归纳偏置,即卷积结构自带的先验知识,例如平移不变性和包含局部关系,因此在规模不足的数据集上表现没有那么好,需要更多的数据来学习到某种特征。

上次编辑于:
贡献者: lisenjie757
已到达文章底部,欢迎留言、表情互动~
  • 赞一个
    0
    赞一个
  • 支持下
    0
    支持下
  • 有点酷
    0
    有点酷
  • 啥玩意
    0
    啥玩意
  • 看不懂
    0
    看不懂
评论
  • 按正序
  • 按倒序
  • 按热度
Powered by Waline v2.14.9