Attention 论文笔记
Attention
梦开始的地方:Attention Is All You Need
本文不再介绍背景等信息,而是重点关注具体算法与实现
注意力机制
一维注意力
注意力分数指的是:对于一个输入 Q,通过和 K 计算某种关系,得到 V 的权重,最后权重乘以 V 得到最后的分数
举个例子,对于这个公式,其中 $\alpha(x, x_i)$ 是注意权重,$-\frac{1}{2}(x - x_i)^2$ 表示的是注意力分数:
$$ f(x) = \sum_i \alpha(x, x_i) y_i = \sum_{i=1}^{n} \operatorname{softmax}\!\left(-\frac{1}{2}(x - x_i)^2\right) y_i $$
上式是计算 $x$ 的注意力分数的公式,表示输入 当前输入 $x$ 去和一组记忆位置 $x_i$ 做相似度比较,得到权重
$\alpha(x,x_i)$,再用这些权重对对应的 $y_i$ 加权求和。
直观理解就是,如果 $x$ 和 $x_i$ 非常近,那么对应的权重就会上升,所以对应的输出就会更大
向量化
扩展到向量后,形式几乎不变,但是差平方替换为向量的平方范数 $\|x-x_i\|^2 = \sum_{m=1}^d (x_m-x_{i,m})^2$ ,也就是所有维度的差的平方求和:
$$ f(x)=\sum_{i=1}^n \alpha_i(x)\, y_i, \quad \alpha_i(x)=\frac{\exp\left(-\frac12 \|x-x_i\|^2\right)} {\sum_{j=1}^n \exp\left(-\frac12 \|x-x_j\|^2\right)} $$
上式就是带入了一个 Softmax 的结果
Scaled Dot-Product Attention
最终的目标是缩放点积注意力公式
$$ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$
单个 Query 情况
1 个 query:$q \in \mathbb{R}^{d_k}$,$n$ 个 key:$k_1,\dots,k_n \in \mathbb{R}^{d_k}$,对应 $n$ 个 value:$v_1,\dots,v_n \in \mathbb{R}^{d_v}$
- 计算 query 和每一个 key 的相似度,也就是 $s_i = q \cdot k_i = q^\top k_i$,也就是点积,越大表示向量方向越近
- 缩放:$s = [s_1,s_2,\dots,s_n]$ 执行缩放后为 $\tilde{s}_i = \frac{s_i}{\sqrt{d_k}}$ ,也就是对维度缩放一下
- Softmax:计算成概率权重,$\alpha_i = \frac{e^{\tilde{s}_i}}{\sum_{j=1}^n e^{\tilde{s}_j}}$ 得到一组权重(就是注意力分数)满足 $\alpha_1,\alpha_2,\dots,\alpha_n$,$\alpha_i \ge 0,\quad \sum_i \alpha_i = 1$
- 加权求和:$o = \sum_{i=1}^n \alpha_i v_i$
为什么使用点积?
- 点积对于多个 query 和 key,可以用矩阵并行化一次运算完毕
- 本身具有相似度的含义:$q^\top k = \|q\|\|k\|\cos\theta$
为什么要缩放?
一个很常见的八股文问题,需要掌握数学推导
如果 $q$ 和 $k$ 的每个分量都大致是均值 0、方差 1 的随机变量,那么点积 $q^\top k = \sum_{m=1}^{d_k} q_m k_m$ 是 是 $d_k$ 项的求和
如果每一项的方差大致是 1,那么总方差会随维度增长,大约是:$\mathrm{Var}(q^\top k) \propto d_k$
关于高斯方差累计,对于独立的随机变量有 $\operatorname{Var}\left(\sum_{i=1}^{n} X_i\right)=\sum_{i=1}^{n}\operatorname{Var}(X_i)$
Softmax 对输入的尺度非常敏感,所以大方差会导致指数迅速拉开差距,方差大后几乎会退化成 one-hot 的形式
对 Softmax 计算梯度得到
$$ p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}, \quad \frac{\partial p_i}{\partial z_k}=p_i(\delta_{ik}-p_k) $$
其中 $\delta_{ik}$ 在 $i = k$ 等于 1,否则为 0,对于尖锐的 Softmax 有 $p_m \approx 1,\quad p_j \approx 0\ (j\neq m)$
代入得到:
- 对最大那个位置:$\frac{\partial p_m}{\partial z_m}=p_m(1-p_m)\approx 1\cdot 0=0$
- 对其他位置:$\frac{\partial p_j}{\partial z_j}=p_j(1-p_j)\approx 0$
- 交叉项:$\frac{\partial p_i}{\partial z_k}=-p_i p_k \approx 0$
整个 Softmax 的雅可比矩阵几乎全部都是 0,这也就是梯度消失
矩阵形式
尤其注意一下这里的维度变化
对于输入的 $Q \in \mathbb{R}^{n \times d_k}$,$K \in \mathbb{R}^{n \times d_k}$, $V \in \mathbb{R}^{n \times d_v}$ 三个矩阵计算 Attention:
- 计算分数矩阵:$S = QK^\top$ ,其中$S \in \mathbb{R}^{n \times n}$, $S_{ij} = q_i^\top k_j$ (就是一维的值分布到了各个矩阵位置上)
- 缩放:$\hat{S} = \frac{S}{\sqrt{d_k}}$
- 有时候还会有掩码,也就是不允许看到未来的 token,负无穷到 softmax 分子是 0:
$$ \begin{array}{c} \hat{S}_{ij} = \begin{cases} \hat{S}_{ij}, & \text{允许关注}\\ -\infty, & \text{不允许关注} \end{cases} \end{array} $$
- Softmax:对每一行做 $A = \mathrm{softmax}(\hat{S})$ ,这里没有发生维度变化$A \in \mathbb{R}^{n \times n}$,只是改成了概率分布
- 输出:$O = AV$,其中$O \in \mathbb{R}^{n \times d_v}$ ,$o_i = \sum_{j=1}^n A_{ij} v_j$ 表示第 $i$ 个位置从全序列汇总得到的新表示
实际上输入最基本的维度要求是这样的:
- $Q \in \mathbb{R}^{n_q \times d_k}$
- $K \in \mathbb{R}^{n_k \times d_k}$
- $V \in \mathbb{R}^{n_k \times d_v}$
主要有两点要求:
- $Q$ 和 $K$ 的内积维 $d_k$ 必定相等,因为要做点积
- $K$ 和 $V$ 序列长度 $n_k$ 必须一样,但是特征维度$d_k$ $d_v$可以不相同
为什么 Softmax 对行做?
因为每一个行对应一个 query, $[S_{i1},S_{i2},...,S_{in}]$ 表示第 $i$ 个 query 对所有 key 的打分
Self-Attention & Cross-Attention
自注意力指的是,所有的 QKV 都来自与一个 $X$
对于序列长度是 $n$,每个 token 的输入表示维度是 $d_{\text{model}}$,有输入矩阵$X \in \mathbb{R}^{n \times d_{\text{model}}}$
将输入经过三个投影矩阵计算后得到:
$$ Q = XW_Q,\quad K = XW_K,\quad V = XW_V $$
各自的维度是: $W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$
后续的算法都是跟矩阵形式是一样的了
而 交叉注意力唯一的区别就是:
$$ Q=X_1W_Q,\quad K=X_2W_K,\quad V=X_2W_V $$
多头注意力
简介
普通的单头注意力只有一套 $W_Q,W_K,W_V$,因此只能在一个子空间里做一次注意力匹配,一次只能学习到一次关系模式
但是在同一句话中模型需要同时关注多种关系,因此可以投射子空间来强化理解能力
实现
多头注意力的做法是把 $Q,K,V$ 分别投影到 多个不同的子空间,每个子空间各自做一次注意力,最后再拼接起来。
对于输入 $X \in \mathbb{R}^{n \times d_{\text{model}}}$,头数 $h$,每个头的维度 $d_k=d_v=d_{\text{model}}/h$(注意这里 Q 和 K 的最后一维必相同),第 $i$ 个头有:
$$ \text{head}_i = \mathrm{Attention}(Q_i, K_i, V_i) $$
其中:$W_Q^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_K^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k}$ $W_V^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_v}$ 计算得到
$$ Q_i=XW_Q^{(i)} \in \mathbb{R}^{n \times d_k}, K_i=XW_K^{(i)} \in \mathbb{R}^{n \times d_k}, V_i=XW_V^{(i)} \in \mathbb{R}^{n \times d_v} $$
计算注意力权重:
$$ A_i=\mathrm{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right), Q_iK_i^T \in \mathbb{R}^{n \times n} $$
因此每一个头都有自己的注意力矩阵,每个头运算得到结果:
$$ \text{head}_i=A_iV_i \in \mathbb{R}^{n \times d_v} $$
执行矩阵拼接得到 $H=\mathrm{Concat}(\text{head}_1,\dots,\text{head}_h)\in \mathbb{R}^{n \times (h d_v)}$ (也就是第二个维度执行左右拼接),最后乘以输出矩阵$W_O$
目的是把所有头拼接后的结果再映射回模型维度,得到 $Y \in \mathbb{R}^{n \times d_{\text{model}}}$
$$ Y=HW_O,\quad W_O\in\mathbb{R}^{h d_v \times d_{\text{model}}} $$
这里 $W_O$ 类似于整合作用,把所有的信息混合到一个统一的表示
位置编码
Attention 架构中无法感知所有 token 之间的顺序,因此需要位置编码结合到 Embedding 中,让模型感知到 token 的位置
一般而言:对一个 token 的向量$x_i \in \mathbb{R}^{d_{\text{model}}}$,加入位置编码后成为 $z_i = x_i + PE(i)$
总体而言 PE 分为两大类,绝对位置编码和相对位置编码
绝对指的是,0123 这种绝对位置,Attention 论文中用的就是固定的正余弦位置编码;相对指的是表达两个 token 之间的相对位置
在论文中定义 PE 为下式,其中$pos$是位置,$i$是维度索引的一半,$d_{\text{model}}$是模型维度
$$ PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
$$ PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
为什么用正余弦?
- 最基本的:不同的位置可以得到不同的值,这可以区分位置
- 编码具有连续性:位置相近则编码结果相似,距离远则位置更远
- 可以学习相对位置:$\sin(a+b),\cos(a+b)$ 可以通过由 $\sin a,\cos a$ 与偏移量 $b$ 的关系表示(三角和公式)
- 可以外推:因为是公式生成的,所以扩展到训练中没有见过的更长位置
Transformer 架构

输入层
指的是 Embedding + Position Encoding 向量模块,整体流程是:
- tokenizer 将整个句子切分一下,常见的方式有 BPE 组合
- Embedding,将每一个 token 映射为一个向量 $x_i \in \mathbb{R}^{d_{\text{model}}}$,隐藏维度 $d_{\text{model}}$,序列长度$n$
- PE 加到 Embedding 结果中:$Z = X + PE$, $Z$ 是 Encoder 的输入
在论文中超参数 $d_{\text{model}} = 512$
Encoder
整个 Encoder 包括以下这些结构,构成一个 Block,在论文中是堆叠了 $N = 6$ 次
- Multi-Head Attention(是 Self-Attention)
- Add & Norm
- Position-wise Feed Forward Network
- Add & Norm
Transformer 各个层的序列的长度和维度全都不变,隐藏维度也保持不变,所以很方便堆叠很多层
MHA
使用 Self-Attention 实现的 MHA,QKV 来自于同一个 $X$,也就是上文中的 $Z$
先执行线性映射,对输入映射到 $Q = XW^Q,\quad K = XW^K,\quad V = XW^V$
然后按照 MHA 流程切分到$h$个头(论文中超参数 $h = 8$),每一个头计算 $\text{head}_h = \text{softmax}\left(\frac{Q_hK_h^T}{\sqrt{d_k}}\right)V_h$
最后执行拼接 $\text{MultiHead}(X) = \text{Concat}(\text{head}_1,\dots,\text{head}_H)W^O$ ,输出维度仍然是:$\mathbb{R}^{n \times d_{\text{model}}}$
Add & Norm
残差连接(Residual Connection)+ 层归一化(LayerNorm):
$$ \text{LayerNorm}(X + \text{Sublayer}(X)) $$
其中 $\text{Sublayer}(X)$ 是上一个 MHA 的变换后的输出结果
Layer Norm 算法指的是,对于 $x = [x_1, x_2, \dots, x_d]$ 序列计算均值和方差:$\mu = \frac{1}{d}\sum_{i=1}^{d} x_i$, $\sigma^2 = \frac{1}{d}\sum_{i=1}^{d}(x_i - \mu)^2$,执行归一化:
$$ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} $$
其中 $\epsilon$ 是一个很小的数,防止除零。之后再使用一个可学习的额缩放和平移:
$$ y_i = \gamma_i \hat{x}_i + \beta_i $$
作用是先执行标准化,之后再让模型学习一个更加合适的分布
为什么用 Layer Norm 而不是 Batch Norm 也是个很常见的问题,这里省略
残差连接的功能是帮助梯度传播,减轻网络太深导致的退化问题。(为什么呢?)
Feed Forward Network + Add & Norm
前馈网络,指的是对每一个位置做一个相同的 MLP
$$ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$
原论文用的是两层线性层,中间 ReLU,第一层:从 $d_{\text{model}}$ 升到 $d_{ff}$;第二层:从 $d_{ff}$ 降回 $d_{\text{model}}$,就是一个多层感知机,论文中这个超参数 $d_{ff} = 2048$
Attention 的作用是 token 之间的信息交互,FFN 的作用是让 token 与对自己的做非线性变换
在 FFN 后接一个残差连接,承接作用
Decoder
Decoder 也是堆叠 $N = 6$ 层,但每层比 Encoder 多一个注意力模块,包含(这里把 Add & Norm 结合到上一层了)后文省略 Add & Norm:
- Masked Multi-Head Self-Attention + Add & Norm
- Encoder-Decoder Attention(Cross-Attention)+ Add & Norm
- Feed Forward + Add & Norm
因为 Decoder 不仅仅需要自己的生成信息,还需要输入句子的相关信息,SA 负责看已经生成的 token 的信息,CA 负责查看 Encoder 的内容
MHA
Decoder 的输入是当前输入的是将目标整体向右移动一位的输入,也就是:
对于目标:
<bos> 我 喜欢 学习 Transformer <eos>Decoder 的输入是:
<bos> 我 喜欢 学习 TransformerDecoder 的期望输出和监督目标是(<bos>是启动生成的编码):
我 喜欢 学习 Transformer <eos>因为模型不应该看到预测的目标,所以这里会出现一个 Mask 部分,实际上就是累加一个下三角矩阵 $M$ :
$$ \begin{array}{c} M_{ij} = \begin{cases} 0, & j \le i \\ -\infty, & j > i \end{cases} \end{array} $$
回顾一下 负无穷在 Softmax 的输出就是 0,这保证了 Decoder 的输出是自回归的,不依赖未来
Cross-Attention
Decoder 相对于 Encoder 最大的区别在这里,这里的 $Q$ 来自 Decoder 当前隐状态,但是 $K,V$ 来自 Encoder 输出:
$$ Q = H_{\text{dec}} W_Q,\quad K = H_{\text{enc}} W_K,\quad V = H_{\text{enc}} W_V $$
$H_{\text{enc}}$:Encoder 最后一层输出的整段序列表示,$H_{\text{dec}}$:Decoder 在进入 cross-attention 前的输入表示
注意这里$W_Q, W_K, W_V$是这一层 cross-attention 自己学习的参数,不是 Encoder 中缓存的 KV
CA 作用是让 Decoder 在生成的时候会额外考虑输入句子的相关内容,类似于Decoder 在边生成边对输入做检索
FFN
最后在输出的地方执行 FFN + Add & Norm 步骤,作用同样是 token 与自己交互
输出层
Decoder 最后一层的输出是 $Y \in \mathbb{R}^{n \times d_{\text{model}}}$ ,通过一个线性层,映射到词表中:
$$ \text{logits} = YW_{vocab} + b $$
对于词表大小是 $V$,则有:$\text{logits} \in \mathbb{R}^{n \times V}$
最后对每一个位置执行 Softmax 可以计算出每一个词语的概率:
$$ P(y_t \mid y_{<t}, x) = \text{softmax}(\text{logits}_t) $$
KV Cache
论文中没有直接提出这个,但是这是一个非常常用的工程优化方案
简介
对于一个 Decoder 生成中过程,假设已经生成了: $y_1, y_2, y_3$ 现在开始预测 $y_4$
没有 Cache 的 Decoder,那么会把整个序列 $[y_1, y_2, y_3]$, 重新送进模型,再算一次 self-attention。
等要预测 $y_5$ 时,又把:$[y_1, y_2, y_3, y_4]$整个再算一遍,于是前面那些 token 的 $K,V$ 会被重复计算很多次。
KV Cache 的核心思想是:历史 token 的 Key 和 Value 一旦算出来,后续生成时就不变,缓存起来复用。
(感觉有点像 DP 里面的记忆化搜索hhhh,简而言之就是缓存减少重复运算)
具体对象
对于某一个 Decoder 中的某一个 Attention 模块,有隐藏状态:$X \in \mathbb{R}^{T \times d_{\text{model}}}$, 经过线性映射得到:
$$ Q = XW^Q,\quad K = XW^K,\quad V = XW^V $$
经过 MHA 以及多 Batch 得到:$Q, K, V \in \mathbb{R}^{B \times H \times T \times d_{\text{head}}}$
KV Cache 就是缓存每一层历史里的历史位置:$K_{\text{past}}, V_{\text{past}}$,
也就是: $K_{\text{cache}} \in \mathbb{R}^{B \times H \times T_{\text{past}} \times d_{\text{head}}}$ , $V_{\text{cache}} \in \mathbb{R}^{B \times H \times T_{\text{past}} \times d_{\text{head}}}$
计算例子
如果是在 Transformer decoder 推理 里做 KV cache,缓存空大小:
$$ \text{Cache bytes} = \text{N} \times (\text{需要缓存的 attention 模块数/层}) \times 2 \times B \times n \times \text{bytes\_per\_elem} $$
- 层数:Decoder Block 个数,经典值 $N = 6$
- Attention 个数:也就是每一个 block 有几个 Attention 模块
- 2 :这个 2 指的是 KV 各要一份缓存空间,所以一个 Attention 模块是 2 份
- $B$:Batch Size
- $n$ :序列长度
- $\text{bytes\_per\_elem}$:每一个数据的结构大小
实际上严谨一些的话也不完全对,在 Decoder 架构中的 SA 和 CA 的长度并不一致:
SA 缓存的是 decoder 已生成序列,长度是 $n$;CA的 $K,V$ 来自 encoder 输出,长度应该是源序列长度,记为 $m$
更加严格的写法是:$(2nd + 2md) \times B \times N \times (\text{需要缓存的 attention 模块数/层})$