深入剖析Transformer:从原理到应用

本文图文并茂地解析了Transformer模型,深入了解其内部机制,核心概念与应用,助力读者理解并应用该模型。

原文标题:独家 | 用图逐层解析Transformer

原文作者:数据派THU

冷月清谈:

本文以图文并茂的方式,逐层解析了Transformer模型,从其基本概念、内部结构到训练和应用过程进行了详细阐述。文章通过对模型中关键环节的张量维度、数据流动等进行可视化展示,帮助读者理解多头注意力机制、掩码机制等核心概念。同时,文章还探讨了Transformer在机器翻译任务中的应用,并通过一个简单的例子展示了如何使用训练好的模型进行推理。总而言之,本文旨在降低Transformer的学习门槛,使更多人能够理解和应用这一强大的深度学习模型。

怜星夜思:

1、Transformer模型中的多头注意力机制是如何提升模型性能的?具体来说,不同的“头”可能关注到输入序列中的哪些不同信息?
2、文章中提到,在训练过程中使用了掩码(mask)机制。那么,除了防止模型“作弊”外,掩码机制还有哪些其他的应用场景?
3、文章中提到Beam Search(束搜索),这是一种常用的解码策略。你觉得除了Beam Search,还有哪些其他的解码策略?它们各自的优缺点是什么?

原文内容

文:Eric Silberstein
翻译:周梓溢
校对:陈超‍‍
本文约6200字,建议阅读10+分钟
你到底输入了什么,你到底得到了什么,你是如何用Transformer生成文本的?


用图逐层解析Transformer

上周我在Acquired上收听了一集英伟达的播客节目。这一集谈到了Transformer:GPT 中的 T 和 21 世纪最大发明的候选人。

在 Beacon街道上边走边听,我在想,我是了解Transformer的,对吧?Transformer在训练过程中掩蔽了Token,让注意力头将文本中的概念联系起来,以此来预测下一个单词的概率。我已经从 Hugging Face 下载大语言模型并尝试使用。早些时候,当我在使用 GPT-3时,它的“聊天”部分还被完全开发出来。在 Klaviyo,我们甚至在我们的主题词助手中构建了第一批由 GPT 驱动的生成式AI功能。很久以前,我还开发过一个由老式语言模型驱动的语法检查器。Transformer也许是这样吧。

Transformer 是 Google 的一个团队提出的,这个团队致力于自动翻译,例如从英语到德语。它于 2017 年著作 Attention Is All You Need 中被介绍给全世界。我调出论文,看了看图 1(下图所示):

图 1 来自 Attention Is All You Need

如果我能理解的话,那只是在最模糊的层面。我越看图表、读论文,就越意识到我没有了解细节。以下是我写下的几个问题:

  • 在训练期间,输入是英语的标记化句子,输出是德语的标记化句子吗?
  • 单个训练批次中的每个项目到底是什么?
  • 为什么要将模型输出的结果喂给模型,“带有掩码的多头注意力 ”如何足以防止它通过从输出的结果中学习输出来作弊?
  • 究竟什么是多头注意力?
  • 损失究竟是如何计算的?不可能是把一个源语言句子翻译完整,然后计算损失,这不合理。
  • 训练后,究竟喂进去什么才能生成翻译?
  • 为什么有三个箭头指向多头注意力模块?

我敢肯定,这些问题对两类人来说很容易而且听起来很幼稚。第一种是已经使用相似模型(例如 RNN、编码器-解码器)来做相似事情的人。当他们阅读论文时,一定立刻就明白 Google 团队完成了什么以及他们是如何做到的。第二个是更多的人意识到Transformers 在过去七年中的重要性,并花时间了解细节。

我想学习,我认为最好的方法是从头开始构建模型。我很快就迷失了方向,并决定跟着别人写的代码学习。我找到了这个很棒的笔记本,它解释了这篇论文并在 PyTorch 中实现了这个模型。我复制了代码并训练了模型。我把所有东西(输入、Batch、词汇、维度)都保持得很小,这样我就可以追踪每一步发生的事情。我发现,注意图表上的维度和张量有助于我保持理顺思路。当我完成时,我对上述所有问题都有了很好的答案,我将在解析完图表之后回答它们。

以下是我的笔记整理版本。这部分的所有内容都是为了训练一个单一的小批量,这意味着不同图表中的所有张量都是一起的。

为了便于理解并从 Notebook 中复制想法,我们将训练模型来复制 Token。例如,经过训练,“dog run”应该翻译成“dog run”。


换句话说:


下面,我们尝试用语言来解释到目前为止图上的张量维度(紫色显示)的含义:


其中一个超参数是 d-model,在论文的基本模型中一般是 512。在这个例子中,我把它设为 8。这意味着我们的嵌入向量的长度为 8。下面是主图,很多地方都标注了维度:


让我们放大编码器的输入:


图中展示的大部分模块(如残差连接与归一化、前馈网络、最终线性变换)仅作用于最后一个维度(8)。如果只是做这些处理,那么模型将只能使用序列中单个位置的信息来预测单个位置。在某个地方,存在某种机制让不同位置的信息“混合”,而这种奇迹发生在多头注意力模块中。

让我们放大编码器中的多头注意力模块。对于下一张图,请记住在我的示例中,我将超参数 h(头数)设置为 2。(在论文的基本模型中,它是 8)。

图 2 来自 Attention Is All You Need,作者提供注释

(2,3,8) 是怎么变成 (2,2,3,4) 的?我们进行了线性变换,然后获取结果并将其拆分为头数 (8 / 2 = 4) 并重新排列张量维度,这样我们的第二个维度是头部。让我们看看一些实际的张量:


我们还没有做任何在不同位置之间混合信息的工作。下一步,我们将在缩放点积注意模块中实现这一点。“4”维度和“3”维度最终将相接。

图 2 来自 Attention Is All You Need,作者提供注释

让我们看看张量,但为了更容易理解,我们只查看batch中的第一项和第一个头数。换句话说,就是Q[0,0]、K[0,0] 等。其他三个注意力头也会发生同样的情况。


让我们看看softmax 和 V 的输出之间的最终矩阵乘法:


从一开始,我们就可以看到,在进行乘法运算之前,从我们最初的句子“dog run ”开始,V 中的三个位置都是独立运算的。这次乘法第一次融合了来自其他位置的信息。

回到多头注意力图,我们可以看到 concat 将每个头的输出重新放在一起,因此每个位置现在都由长度为 8 的向量表示。请注意,在 concat 之后但线性之前的张量中的 1.8 和 -1.1 与向量中前两个元素的 1.8 和 -1.1 匹配,这些元素来自上面所示的缩放点积注意的输出中批处理中第一项中第一个头的第一个位置。(接下来的两个数字也匹配,但它们被省略号隐藏了。


现在让我们缩小到整个编码器:


起初,我认为我会详细跟踪前馈块。它在论文中被称为“位置前馈网络”,我认为这意味着它可能会将信息从一个位置带到它右侧的位置。 然而 ,事实并非如此。“位置方面” 意味着它在每个位置上独立运行。它对从 8 个元素到 32 个元素的每个位置进行线性变换,执行 ReLU(线性整流函数,最大值为 0 和数字),然后进行另一次线性变换以返回到 8。(这是我们的小例子。在论文的基本模型中,它从 512 到 2048,然后回到 512。这里有很多参数,可能这就是很多学习发生的地方!)前馈的输出返回到 (2,3,8)。

暂且先不谈我们的玩具模型,下面是论文中基本模型中编码器的样子。输入和输出尺寸一致,这一点非常好!


现在让我们完全缩小,以便查看解码器。


我们不需要跟踪大部分解码器端,因为它与我们刚刚在编码器端看到的非常相似。但是,我标记为A 和 B 的两部分是不同的。A 部分之所以不同,是因为我们进行掩蔽多头关注。这一定是在训练时不会 “作弊 ”的神奇之处。B 部分我们稍后再来讨论。不过,首先让我们隐藏内部细节,并牢记我们希望从解码器中得到什么。


为了真正说明这一点,假设我们的英语句子是“she pet the dog”,而我们翻译的黑话句子是 “eshay etpay ethay ogday”。如果模型有 “eshay etpay ethay” 并试图想出下一个词,那么 “ogday” 和 “atcay” 都是高概率的选择。考虑到 “she pet the dog” 这个完整的英文句子的上下文,它真的应该选择 “ogday”。但是,如果模型在训练期间可以看到 “ogday”,则它不需要学习如何使用上下文进行预测,它只需学习复制即可。

让我们看看掩码是如何做到这一点的。我们可以跳过前面这部分,因为 A 的第一部分的工作方式与之前完全相同,它用线性转换并将内容拆分为多个头。唯一的区别是进入缩放的点积注意力部分的维度是 (2,2,2,4) 而不是 (2,2,3,4),因为我们的原始输入序列的长度为 2。下面是缩放的点积注意力部分。正如我们在编码器端所做的那样,我们只关注批处理中的第一项和第一个头。


这里我们有一个掩码。让我们看看 softmax 和 V 的输出之间的最终矩阵乘法:


现在,我们看看B部分,解码器中的第二个多头注意力。与其他两个多头注意力块不同,我们没有输入三个相同的张量,所以我们需要考虑什么是 V,什么是 K,什么是 Q。我用红色标记了输入。我们可以看到编码器的输出的 V 和 K 维度都是 (2,3,8)。Q 的维度是 (2,2,8)。


和以前一样,我们关注缩放的点积注意力部分。V 和 K 的维度为 (2,2,3,4) – batch中的两个项目,两个头,三个位置,长度为 4 的向量,而 Q 的维度为 (2,2,2,4),这是有道理的,但也令人困惑。


即使我们正在“读取” 编码器输出的结果,其中 “序列” 长度为 3,但不知为什么,所有的矩阵数学运算都成功了,我们最终得到了我们想要的维度 (2,2,2,4)。让我们看看最后的矩阵乘法:


每个多头注意力块的输出相加。让我们回到前面,看看解码器的输出,并将其转换为预测值:


线性变换将我们从 (2,2,8) 带到 (2,2,5)。 可以将其视为反向嵌入,不同之处在于我们不是从长度为 8 的向量到单个Token的整数标识符,而是转到 5 个Token的概率分布。我们小例子中的数字使这看起来有点滑稽。在论文中,当他们把英语翻译成德语时,这更像是从 512 大小的向量变成 37,000 个词汇。

稍后我们将计算损失。不过,首先,即使只是看一眼,您也可以了解模型的运行情况。


它答对了一个Token。这并不奇怪,因为这只是我们的第一个训练Batch,而且都是随机的。这张图的一个好处是,它清楚地表明这是一个多分类问题。这些类是词汇量(在本例中为 5 个类), 这就是我之前感到困惑的地方,我们对翻译句子中的每个Token进行预测(和评分),而不是每个句子一个预测。我们来做实际的损失计算。


例如,如果-3.2 变为 -2.2,我们的损失将减少到 5.7,朝着所需的方向移动,因为我们希望模型了解第一个Token的正确预测是 4。

上图没有对标签做平滑处理。在实际论文中,损失计算对标签进行平滑处理,并使用 KL Divergence 损失。我认为在没有平滑的情况下,这与交叉熵相同或类似。这是与上面相同的图表,但使用了标签平滑。


我们还可以快速了解一下编码器和解码器中正在学习的参数数量:


作为合理性检查,我们的玩具模型中的前馈块具有从8 到 32 再回到 8 的线性变换(如上所述),因此 8 32(权重)+ 32(偏差)+ 32 8(权重)+ 8(偏差)= 52。请记住,在本文的基本模型中,d-model 为 512,d-ff 为 2048,并且有 6 个编码器和 6 个解码器,将有更多的参数。

使用训练过的模型


现在让我们看看如何输入源语言文本并输出翻译的文本。我在这里仍然使用一个玩具模型,这个模型通过应对Token来“翻译”进行训练,但和上面的示例不同,这个模型使用的词汇量是11, 而d-model 为 512。(上面我们的例子中词汇量为5,d-model 是8。)

首先,我们来做一个翻译,看看它是怎样工作的。


第一步是将源句子输入编码器并保留其输出,在本例中为维度为 (1, 10, 512) 的张量。


第二步是将输出的第一个Token 输入到解码器中,并预测第二个 Token。我们知道第一个 Token,因为它总是等于1。


在论文中,他们使用的是束搜索,束大小为4,这意味着我们将在这一点上考虑 4 个最高概率的Token。为了简单起见,我将改用贪心算法。您可以将其视为束大小为 1 的束搜索。因此,从图的顶部开始,最高概率的Token是数字 5。(上面的输出是概率的对数。最高概率仍然是最大的数字。在本例中,它是 -0.0,实际上是 -0.004,但我只显示了一个小数位。该模型非常有信心 5 是正确的!exp(-0.004) = 99.6%)

现在我们将 [1,5] 输入到解码器中。(如果我们使用束大小为 2 进行束搜索,我们可以改为输入包含 [1,5] 和 [1,4] 的Batch,这是下一个最有可能的Batch。)


现在我们输入[1,5,4]:


依此类推,直到我们得到一个表示句子结尾的Token(在我们的示例词汇表中不存在)或达到最大长度。


回到上面的问题


现在,我基本上可以回答我最开始的问题了。

在训练期间,输入是英语的标记化句子,输出是德语的标记化句子吗?


是的,或多或少。

训练Batch中的每个项目到底是什么?


每个项目对应于一对翻译的句子。

  • 项目的 “x” 有两个部分。第一部分是源句子的所有Token。第二部分是目标句子的所有Token,除了最后一部分。
  • 项目的 “y” (label) 是目标句子的所有Token,除了第一个Token。由于源句子和目标句子的第一个 Token 始终是相同的,因此我们不会浪费或丢失任何训练数据。


有些微妙的是,如果这是一个分类任务,比如模型必须拍一张图并输出一个类(房子、汽车、兔子等),我们会将Batch中的每个项目看作是为损失计算贡献一个“分类”。但是,在这里,Batch中的每个项目都将为损失计算贡献(number_of_Tokens_in_target_sentence – 1) 个“分类”。

为什么要将模型输出的结果输入到模型中,“带有掩码的多头注意力 ”如何足以防止它通过从输出的结果中学习输出来作弊?

如果您将模型输出的结果输入到模型中去,模型可以学习基于源句子的含义和到已经翻译的单词来预测翻译的结果。尽管模型中做了很多事情,但信息在位置之间移动的唯一时间是在注意力步骤上。即使我们将翻译后的句子输入到解码器中,但第一次注意力计算会使用掩码将所有超出我们预测之外的位置信息归零。

究竟什么是多头注意力?

我可能要问一下究竟什么是注意力,因为这是更核心的概念。多头注意力的意思是把向量切分成组,关注这些组,然后将这些组重新放在一起。例如,如果向量的大小为512 并且有 8 个头,注意力会独立地对 8 组进行关注,每组包含一整批完整位置,每个位置都有一个大小为 64 的向量。如果你眯着眼睛,你可以看到每个头最终是怎样学会关注某些相关概念的,就像可视化一样展示,一个头将如何学习一个代词所指的是什么单词。

损失究竟是如何计算的?不可能是把一个源语言句子翻译完整,然后计算损失,这不合理。

对。我们不会一次性翻译一个完整的句子并计算整体句子相似度或类似的东西。损失的计算方式就和其他多分类问题一样。这里的类别就是词汇表的各个Token。诀窍是,我们只使用当时应该拥有的信息独立预测目标句子中每个Token的类。标签是目标句子中的实际Token。基于预测值和标签,我们可以通过交叉熵计算损失。(事实上,我们会 使标签“平滑 ”化来处理标签的非绝对性,同义词也能起到同样的作用。)

训练后,究竟输入什么来生成翻译?

不能一次性输入一些内容,让模型直接给出翻译结果。需要多次使用该模型。首先,把源句子输入到模型的编码器部分,得到一个抽象的、有深度的句子版本。然后,把编码信息和起始Token输入到模型的解码器部分。这就可以预测目标句子中的第二个Token。然后根据输入的内容和第二个 Token来预测第三个Token。重复这个动作,直到得到一个完整的翻译句子。(但事实上,模型会考虑每个位置使用多个高概率Token,每次输入多个候选序列,并根据总概率和长度损失选择最终翻译的句子。)

为什么有三个箭头进入多头注意力模块?

我猜有三个原因。1) 说明解码器中的第二个多头注意力区块的部分输入来自编码器,部分来自解码器中的前一个区块。2) 解释注意力算法是如何工作的。3)暗示在实际注意力发生之前,三个输入中的每一个都经历了各自独立的线性变换。

结论

它很漂亮!如果不是因为它非常有用,我可能不会这么想。我现在体会到了人们第一次看到它工作时的感觉。这个用很少代码就能表达的优雅、可训练的模型学会了如何翻译人类语言,并击败了几十年来建立的复杂机器翻译系统。它神奇、聪明,令人难以置信。你可以看到下一步是怎么说的,而不用在意翻译的句子。让我们在互联网上的每一点文本上使用这种技术—大语言模型就这样诞生了!

如果你觉得上面有一些错误,请告知我。

除非另有说明,否则所有图片均由作者提供,包含作者对 Attention Is All You Need 中的图表的注释 。


编辑:黄继彦





译者简介





作者简介

周梓溢,广州大学统计学在读学生,数据科学爱好者。在学习中时常翻阅数据科学英文文献,一直在学习的路上,希望在学习过程输出一些有意义的事情。很高兴加入数据派THU翻译组这个大家庭,期望与大家共同探索数据科学,一起「无限进步」!

翻译组招募信息

工作内容:需要一颗细致的心,将选取好的外文文章翻译成流畅的中文。如果你是数据科学/统计学/计算机类的留学生,或在海外从事相关工作,或对自己外语水平有信心的朋友欢迎加入翻译小组。

你能得到:定期的翻译培训提高志愿者的翻译水平,提高对于数据科学前沿的认知,海外的朋友可以和国内技术应用发展保持联系,THU数据派产学研的背景为志愿者带来好的发展机遇。

其他福利:来自于名企的数据科学工作者,北大清华以及海外等名校学生他们都将成为你在翻译小组的伙伴。


点击文末“阅读原文”加入数据派团队~



转载须知

如需转载,请在开篇显著位置注明作者和出处(转自:数据派ID:DatapiTHU),并在文章结尾放置数据派醒目二维码。有原创标识文章,请发送【文章名称-待授权公众号名称及ID】至联系邮箱,申请白名单授权并按要求编辑。

发布后请将链接反馈至联系邮箱(见下方)。未经许可的转载以及改编者,我们将依法追究其法律责任。





关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。



新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

点击“阅读原文”拥抱组织



我觉得多头注意力就像一个团队合作,每个人负责关注不同的细节。 这样可以避免模型过于关注某个局部信息,从而更好地捕捉全局的依赖关系。就好比侦探破案,有人负责收集线索,有人负责分析动机,最后大家一起合作才能找到真凶。

从数学角度看,多头注意力通过将query, key, value 投影到不同的子空间,使得模型可以学习到不同的注意力模式。每个头关注不同的特征组合,最终concat起来,可以提升模型的表达能力. 类似于卷积神经网络中多个卷积核的作用,每个卷积核提取不同的特征,最终组合起来形成更丰富的特征表示。

想起之前玩游戏王的时候,有些卡牌效果是“无效对方场上所有XX卡的效果”,这里的“无效”其实就可以理解为一种mask。所以我觉得mask本质上就是一种选择性的屏蔽机制,可以根据不同的需求来灵活地控制信息的流动。 比如在图像处理中,我们可以用mask来抠图,把图片中的特定区域提取出来。感觉NLP和图像处理很多地方都是相通的。

掩码在数据隐私保护上也很有潜力。我们可以使用掩码技术隐藏敏感信息,同时允许模型学习数据的总体分布。比方说,在医疗数据分析中,我们可以掩盖患者的姓名和住址等个人信息,但保留疾病和治疗方案等信息,从而在保护患者隐私的同时,让模型能够用于疾病预测和治疗优化。