旋转编码 RoPE
旋转编码是苏剑林老师在论文 Roformer: Enhanced Transformer With Rotray Position Embedding 中提出的,苏神是怎么想到这个idea的详细可以参考他的博采众长的旋转式位置编码, 其他位置编码的调研则参考了苏神的 让研究人员绞尽脑汁的Transformer位置编码。下面是笔者从博客中简单做的总结,对于懒得看数学推导的朋友们可以就看这个结论:
因为只靠Attention模块无法捕捉输入的位置信息。换句话说,将输入随意做permute,再对Attention的输出做unpermute,获得的结果都是一样的。这是一件比较糟糕的事情,为此,我们需要引入位置编码。RoPE的核心思想是”通过绝对位置编码的方式实现相对位置编码”。这句话可能还是有一点点抽象,让笔者详细解释一下: 相对位置编码的目标是让第i个词和第j个词之间距离只和(i - j)相关,传统相对位置编码的实现方式是主要是想办法对Attention计算中间过程里 QK 这个矩阵做一点操作,从而保证 QK 矩阵里的每一行的第i个元素和第j个元素之间有一个和(i - j)相关的差;绝对位置编码的意思就是不去修改Attention里的矩阵,只修改Attention里的输入。而RoPE就只修改了输入,同时还能保留相对位置信息。
既然把”只修改输入前提下保留相对位置”这句话翻译成数学的语言就是下面这行公式。
- 其中f就是我们需要求解的函数,这个函数以向量和它的位置作为输入,输出一个加了位置信息的向量。
- 因为Attention计算里QK的本质就是向量之间两两做内积,所以修改Attention里矩阵和要求内积结果只和(m - n)相关是等价的。 \(\langle\boldsymbol{f}(\boldsymbol{q}, m), \boldsymbol{f}(\boldsymbol{k}, n)\rangle=g(\boldsymbol{q}, \boldsymbol{k}, m-n) \\ \boldsymbol{f}(\boldsymbol{q}, 0) = \boldsymbol{q}\) 求解这个公式,得到结果下面这个函数是上面这个方程的一个解,OK,我们就得到了RoPE了 \(\left(\begin{array}{ccccccc} \cos m \theta_0 & -\sin m \theta_0 & 0 & 0 & \cdots & 0 & 0 \\ \sin m \theta_0 & \cos m \theta_0 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m \theta_1 & -\sin m \theta_1 & \cdots & 0 & 0 \\ 0 & 0 & \sin m \theta_1 & \cos m \theta_1 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m \theta_{d / 2-1} & -\sin m \theta_{d / 2-1} \\ 0 & 0 & 0 & 0 & \cdots & \sin m \theta_{d / 2-1} & \cos m \theta_{d / 2-1} \end{array}\right)\left(\begin{array}{c} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right)\)
上面这个函数又等价于下面这个函数,下面这个函数里的矩阵都是Dense的,因此算起来更高效,所以在后文计算的时候都以下面这个矩阵为准。其中 $ \theta_i=10000^{-2 i / d} $,这点和最经典的Sinusoidal保持一致。 \(\left(\begin{array}{c} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right) \otimes\left(\begin{array}{c} \cos m \theta_0 \\ \cos m \theta_0 \\ \cos m \theta_1 \\ \cos m \theta_1 \\ \vdots \\ \cos m \theta_{d / 2-1} \\ \cos m \theta_{d / 2-1} \end{array}\right)+\left(\begin{array}{c} -q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{array}\right) \otimes\left(\begin{array}{c} \sin m \theta_0 \\ \sin m \theta_0 \\ \sin m \theta_1 \\ \sin m \theta_1 \\ \vdots \\ \sin m \theta_{d / 2-1} \\ \sin m \theta_{d / 2-1} \end{array}\right)\)
另外,在实现上为了更高效,Llama其实是这么实现的,下面也以这种实现为基准。 \(\left(\begin{array}{c} q_0 \\ q_1 \\ \vdots \\ q_{d/2-1} \\ q_{d/2} \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right) \otimes\left(\begin{array}{c} \cos m \theta_0 \\ \cos m \theta_1 \\ \vdots \\ \cos m \theta_{d / 2-1} \\ \cos m \theta_{0} \\ \vdots \\ \cos m \theta_{d / 2-2} \\ \cos m \theta_{d / 2-1} \end{array}\right)+\left(\begin{array}{c} -q_0 \\ -q_1 \\ \vdots \\ -q_{d/2-1} \\ q_{d/2} \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right) \otimes\left(\begin{array}{c} \sin m \theta_0 \\ \sin m \theta_1 \\ \vdots \\ \sin m \theta_{d / 2-1} \\ \sin m \theta_{0} \\ \vdots \\ \sin m \theta_{d / 2-2} \\ \sin m \theta_{d / 2-1} \end{array}\right)\) That’s It! 下面就让我们来实现一版把!
RoPE 实现
让我们以2025年1月初Megatron-LM里的一个实现作为baseline,链接。下面简单写一个伪代码。
- 公式里的m通过
torch.arrange(seq_len)
生成。torch.arrange
的结果完全可以复用,因此可以Cache - 公式里的 theta 对应
inv_freq
,其结果也可以复用 - outer 的输入是两个1维的向量,输出是一个二维的向量,其含义是输入的笛卡尔积
_rotate_half
生成最后那个公式里右边的那部分。1 2 3 4 5 6 7 8
def _rotate_half(x): x1, x2 = torch.chunk(x, 2, dim=-1) return torch.cat((-x2, x1), dim=-1) inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim, 2) / dim)) freqs = torch.outer(torch.arange(seq_len), inv_freq) cos_ = torch.cos(freqs) sin_ = torch.sin(freqs) t = (t * cos_) + (_rotate_half(t) * sin_)
Cuda 实现
Apex,TransformerEngine,FlashAttention都提供了API,具体依然可以参考Megatron里的代码。本文的目的既不是想探究API怎么用(这部分看Megatron代码就好了);也懒得自己写一个CUDA代码了,因此直接跑一个实验,看看这几个API实现的效果如何,从而大概了解一下这种类型的函数应该怎么实现。 首先,下面是不同API在A100上的实验结果。可以看到FlashAttention的实现其实还差一点, Apex和TransformerEngine的实现是最优的。恰好笔者对TransformerEngine这个库比较熟,就直接看看这里是怎么实现的吧。
RotaryEmbeddingMegatron take 73.508257 ms while size = 20971520
RotaryEmbeddingApex take 14.397789 ms while size = 20971520
RotaryEmbeddingTE take 14.080192 ms while size = 20971520
RotaryEmbeddingFlash take 21.117988 ms while size = 20971520
具体代码是FusedRoPEFunc
这个函数,一上来就是巨大的槽点。也就是说,不管输入是什么类型,TransformerEngine都会先把他转成Float32。那么很自然的大家就会想到,如果输入是BFloat16,正确性还能保证吗?笔者做了实验,发现把Input的Dtype从Float改成BFloat16之后,FlashAttention和TransformerEngine都和baseline对不齐。emmm只能说世界就是巨大的草台班子,大家调用这种API之前最好还是本地实测一下。至于Megatron很机智的用了Apex的实现做了默认的实现,不知道是不是出于这个正确性的考虑。
if freqs.dtype != torch.float32:
freqs = freqs.float()
让我们再看看具体CUDA实现,毕竟这个实现性能还是很好的。 让我们无视掉每个变量具体的含义,快速的看一些High level Idea,思想很简单,就是每个线程去做Load/Store。这个函数是典型的Memory Bound的函数,因此sincosf这个计算绝对是Free的。这也解释了,为什么FlashAttention的实现是提前算好sin/cos,而TransformerEngine是每次计算sin/cos,反而是后者的性能更好,答案是因为计算根本就是free的。我看了FlashAttention的代码,他是用Triton实现的,那确实性能比CUDA差是符合预期的。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2)
? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
本文就到此结束了,代码都在笔者的这个github仓库里。如果在想深究,追求一个极致,还可以想办法通过float4这种来加速Load/Store。或许日后某天会补上吧。