Redian新闻
>
十分钟读懂旋转编码(RoPE)

十分钟读懂旋转编码(RoPE)

公众号新闻

©作者 | 绝密伏击

单位 | 奇虎360高级算法专家


旋转位置编码(Rotary Position Embedding,RoPE)是论文 Roformer: Enhanced Transformer With Rotray Position Embedding 提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA、GLM 模型也是采用该位置编码方式。


和相对位置编码相比,RoPE 具有更好的外推性,目前是大模型相对位置编码中应用最广的方式之一。


备注:什么是大模型外推性?


外推性是指大模型在训练时和预测时的输入长度不一致,导致模型的泛化能力下降的问题。例如,如果一个模型在训练时只使用了 512 个 token 的文本,那么在预测时如果输入超过 512 个 token,模型可能无法正确处理。这就限制了大模型在处理长文本或多轮对话等任务时的效果。




旋转编码RoPE


1.1 基本概念


在介绍 RoPE 之前,先给出一些符号定义,以及基本背景。


首先定义一个长度为 的输入序列为:

其中 表示输入序列中第 个 token,而输入序列 对应的 embedding 表示为:

其中 表示第 个 token 对应的 维词嵌入向量。

接着在做 self-attention 之前,会用词嵌入向量计算 向量同时加入位置信息,函数公式表达如下:

其中 表示第 个 token 对应的词向量 集成位置信息 之后的 query 向量。而 则表示第 个 token 对应的词向量 集成位置信息 之后的 key 和 value 向量。

而基于 transformer 的位置编码方法都是着重于构造一个合适的 函数形式。

而计算第 个词嵌入向量 对应的 self-attention 输出结果,就是 和其他 都计算一个 attention score ,然后再将 attention score 乘以对应的 再求和得到输出向量

1.2 绝对位置编码


对于位置编码,常规的做法是在计算 query,key 和 value 向量之前,会计算一个位置编码向量 加到词嵌入 上,位置编码向量 同样也是 维向量,然后再乘以对应的变换矩阵

而经典的位置编码向量 的计算方式是使用 Sinusoidal 函数:

其中 表示位置 维度向量 中的第 位置分量也就是偶数索引位置的计算公式,而 就对应第 位置分量也就是奇数索引位置的计算公式。


1.3 2维旋转位置编码


论文中提出为了能利用上 token 之间的相对位置信息,假定 query 向量 和 key 向量 之间的内积操作可以被一个函数 表示,该函数 的输入是词嵌入向量 和它们之间的相对位置

接下来的目标就是找到一个等价的位置编码方式,从而使得上述关系成立。

假定现在词嵌入向量的维度是两维 ,这样就可以利用上 2 维度平面上的向量的几何性质,然后论文中提出了一个满足上述关系的 的形式如下:
这里面 Re 表示复数的实部。

进一步地, 可以表示成下面的式子:
看到这里会发现,这不就是 query 向量乘以了一个旋转矩阵吗?这就是为什么叫做旋转位置编码的原因。

同理, 可以表示成下面的式子
最终 可以表示如下:

关于上面公式(8)~(11)的具体推导,可以参见文章最后的附录,或者参考文章:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)。

1.4 扩展到多维


将2维推广到任意维度,可以表示如下:
内积满足线性叠加性,因此任意偶数维的 RoPE,我们都可以表示为二维情形的拼接,即
将 RoPE 应用到前面公式(4)的 Self-Attention 计算,可以得到包含相对位置信息的 Self-Attetion:

中,

值得指出的是,由于 是一个正交矩阵,它不会改变向量的模长,因此通常来说它不会改变原模型的稳定性。

1.5 RoPE 的高效计算


由于 的稀疏性,所以直接用矩阵乘法来实现会很浪费算力,推荐通过下述方式来实现 RoPE:

其中 是逐位对应相乘,即计算框架中的 运算。从这个实现也可以看到,RoPE 可以视为是乘性位置编码的变体。

总结来说,RoPE 的 self-attention 操作的流程是:对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量,然后对每个 token 位置都计算对应的旋转位置编码,接着对每个 token 位置的 query 和 key 向量的元素按照两两一组应用旋转变换,最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果。

论文中有个很直观的图片展示了旋转变换的过程:


1.6 远程衰减


可以看到,RoPE 形式上和前面公式(6)Sinusoidal 位置编码有点相似,只不过 Sinusoidal 位置编码是加性的,而 RoPE 可以视为乘性的。 的选择上,RoPE 同样沿用了 Sinusoidal 位置编码的方案,即 ,它可以带来一定的远程衰减性。


具体证明如下: 两两分组后,它们加上 RoPE 后的内积可以用复数乘法表示为:

并约定 ,那么由 Abel 变换(分部求和法)可以得到:

所以

因此我们可以考察 随着相对距离的变化情况来作为衰减性的体现:

从图中我们可以看到随着相对距离的变大,内积结果有衰减趋势的出现。因此,选择 ,确实能带来一定的远程衰减性。论文中还试过以 为初始化,将 视为可训练参数,然后训练一段时间后发现 并没有显著更新,因此干脆就直接固定 了。



RoPE实验


我们看一下 RoPE 在预训练阶段的实验效果:

从上面可以看出,增大序列长度,预训练的准确率反而有所提升,这体现了 RoPE 具有良好的外推能力。

下面是在下游任务上的实验结果:
其中 RoFormer 是一个绝对位置编码替换为 RoPE 的 WoBERT 模型,后面的参数(512)是微调时截断的maxlen,可以看到 RoPE 确实能较好地处理长文本语义。



RoPE代码实现


Meta 的 LLAMA 和 清华的 ChatGLM 都使用了 RoPE 编码,下面看一下具体实现。

3.1 在LLAMA中的实现


# 生成旋转矩阵
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度\theta_i
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2] 
    freqs = torch.outer(t, freqs).float()  # 计算m * \theta

    # 计算结果是个复数向量
    # 假设 freqs = [x, y]
    # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 
    return freqs_cis

# 旋转位置编码计算
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
)
 -> Tuple[torch.Tensor, torch.Tensor]:
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -12)
    xk_ = xk.float().reshape(*xk.shape[:-1], -12)

    # 转为复数域
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)

    # 应用旋转操作,然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(selfargs: ModelArgs):
        super().__init__()

        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)

        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(selfx: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(batch_size, seq_len, dim)
        xk = xk.view(batch_size, seq_len, dim)
        xv = xv.view(batch_size, seq_len, dim)

        # attention 操作之前,应用旋转位置编码
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # scores.shape = (bs, seqlen, seqlen)
        scores = torch.matmul(xq, xk.transpose(12)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
  # ......


这里举一个例子,假设 batch_size=10, seq_len=3, d=8,则调用函数 precompute_freqs_cis(d, seq_len) 后,生成结果为:


In [239]freqs_cis
Out[239]
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
        [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j]])


以结果中的第二行为例(对应的 m = 1),也就是:
最终按照公式(12)可以得到编码之后的

注意:在代码中是直接用 freqs_cis[0] * xq_[0] 的结果表示第一个 token 对应的旋转编码(和公式 12 计算方式有所区别)。其中将原始的 query 向量 转换为了复数形式。


In [351]: q_ = q.float().reshape(*q.shape[:-1], -12)

In [352]: q_[0]
Out[352]: 
tensor([[[ 1.0247,  0.4782],
         [ 1.5593,  0.2119],
         [ 0.4175,  0.5309],
         [ 0.4858,  0.1850]]
,

        [[-1.7456,  0.6849],
         [ 0.3844,  1.1492],
         [ 0.1700,  0.2106],
         [ 0.5433,  0.2261]]
,

        [[-1.1206,  0.6969],
         [ 0.8371, -0.7765],
         [-0.3076,  0.1704],
         [-0.5999, -1.7029]]
])

In [353]: xq = torch.view_as_complex(q_)

In [354]: xq[0]
Out[354]: 
tensor([[ 1.0247+0.4782j,  1.5593+0.2119j,  0.4175+0.5309j,  0.4858+0.1850j],
        [-1.7456+0.6849j,  0.3844+1.1492j,  0.1700+0.2106j,  0.5433+0.2261j],
        [-1.1206+0.6969j,  0.8371-0.7765j, -0.3076+0.1704j, -0.5999-1.7029j]]
)

这里为什么可以这样计算?

主要是利用了复数的乘法性质。

我们首先来复习一下复数乘法的性质:

因此要计算:

可以转化为计算:

所以可以将公式(12)转化为两个复数的乘法运算。

3.2 在ChatGLM中的实现

和 LLAMA 的实现方式相差不大。代码如下:


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
         # 计算 \theta_i
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()

        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            # 对应m * \theta
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # 将 m * \theta 拼接两次,对应复数的实部和虚部
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]  # 计算得到cos(m*\theta)
            sin_cached = emb.sin()[:, None, :]  # 计算得到cos(m*\theta)
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  



RoPE的外推性


我们都知道 RoPE 具有很好的外推性,前面的实验结果也证明了这一点。这里解释下具体原因。

RoPE 可以通过旋转矩阵来实现位置编码的外推,即可以通过旋转矩阵来生成超过预期训练长度的位置编码。这样可以提高模型的泛化能力和鲁棒性。

我们回顾一下 RoPE 的工作原理:假设我们有一个 维的绝对位置编码 ,其中 是位置索引。我们可以将 看成一个 维空间中的一个点。我们可以定义一个  维空间中的一个旋转矩阵 ,它可以将任意一个点沿着某个轴旋转一定的角度。我们可以用 来变换 ,得到一个新的点 。我们可以发现, 的距离是相等的,即 。这意味着 的相对关系没有改变。但是, 的距离可能发生改变,即 。这意味着 的相对关系有所改变。因此,我们可以用 来调整不同位置之间的相对关系。

如果我们想要生成超过预训练长度的位置编码,我们只需要用 来重复变换最后一个预训练位置编码 ,得到新的位置编码
依此类推。这样就可以得到任意长度的位置编码序列 ,其中 可以大于 。由于 是一个正交矩阵,它保证了 的距离不会无限增大或缩小,而是在一个有限范围内波动。这样就可以避免数值溢出或下溢的问题。同时,由于 是一个可逆矩阵,它保证了 的距离可以通过 的逆矩阵 还原到 的距离,即

这样就可以保证位置编码的可逆性和可解释性。

总结而言:

旋转编码 RoPE 可以有效地保持位置信息的相对关系,即相邻位置的编码之间有一定的相似性,而远离位置的编码之间有一定的差异性。这样可以增强模型对位置信息的感知和利用。这一点是其他绝对位置编码方式(如正弦位置编码、学习的位置编码等)所不具备的,因为它们只能表示绝对位置,而不能表示相对位置。

旋转编码 RoPE 可以通过旋转矩阵来实现位置编码的外推,即可以通过旋转矩阵来生成超过预训练长度的位置编码。这样可以提高模型的泛化能力和鲁棒性。这一点是其他固定位置编码方式(如正弦位置编码、固定相对位置编码等)所不具备的,因为它们只能表示预训练长度内的位置,而不能表示超过预训练长度的位置。

旋转编码 RoPE 可以与线性注意力机制兼容,即不需要额外的计算或参数来实现相对位置编码。这样可以降低模型的计算复杂度和内存消耗。这一点是其他混合位置编码方式(如 Transformer-XL、XLNet 等)所不具备的,因为它们需要额外的计算或参数来实现相对位置编码。



总结


最近一直听到旋转编码这个词,但是一直没有仔细看具体原理。今天花时间仔细看了一遍,确实理论写的比较完备,而且实验效果也不错。目前很多的大模型,都选择了使用了这种编码方式(LLAMA、GLM 等)。



附录


这里补充一下前面公式 1.3.2 节中,公式(8)~(11)是怎么推导出来的。

回到之前的公式(8),编码之后的 以及内积 的形式如下:

上面的公式为什么满足:

首先我们得先了解一下基本的复数相关知识。

首先看到上述 公式中有个指数函数: 

这个其实是欧拉公式,其中 表示任意实数, 是自然对数的底数, 是复数中的虚数单位,则根据欧拉公式有:

则是上述指数函数可以表示为实部为 ,虚部为 的一个复数,欧拉公式建立了指数函数、三角函数和复数之间的桥梁。

则上述 公式的
然后我们看回公式:
其中 是个二维矩阵, 是个二维向量,相乘的结果也是一个二维向量,这里用 表示:

然后首先将 表示成复数形式:
接着

其实就是两个复数相乘:

然后就有:

将结果重新表达成实数向量形式就是:

这里不难发现就是 query 向量乘以了一个旋转矩阵。

这就是为什么叫做旋转式位置编码的原因。

同理可得 key 向量

最后还有个函数
其中 表示一个复数 的实部部分,而 则表示复数 的共轭。

复习一下共轭复数的定义:

所以可得:

继续可得:

接下来我们就要证明函数 的计算公式是成立的。

首先回顾一下 attention 操作,位置 的 query 和位置 的 key 会做一个内积操作:

接着进行推导,我们整理一下:

这就证明上述关系是成立的,位置 的 query 和位置 的 key 的内积就是函数

把上面的式子用矩阵向量乘的形式来表达就是:


参考文献

[1] ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING https://arxiv.org/pdf/2104.09864.pdf

[2] 梁德澎:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)https://zhuanlan.zhihu.com/p/642884818

[3] 马梦之:一步一步,推导旋转位置编码(Rotary Position Embedding, RoPE)https://zhuanlan.zhihu.com/p/644585013

[4] Transformer升级之路:博采众长的旋转式位置编码



更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:[email protected] 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·
·


微信扫码关注该文公众号作者

戳这里提交新闻线索和高质量文章给我们。
相关阅读
《长安三万里》爆火:读懂了李白,也就读懂了人生读懂马斯克这30句话,就读懂了特斯拉一分钟读懂:诺奖得主为新冠疫苗做了啥贡献?京都麓酒店(Roku Kyoto, LXR Hotels & Resorts) 入住体验一文读懂OpenAI创始人的“世界币”比 GitHub Copilot 更强大?Meta 开源 AI 编码工具,能跨多语言补全和调试代码华为 Mate 60 Pro一分钟售罄;传 OpenAI 秘密训练GPT-5;中国全年汽车出口或超 500 万辆 | 极客早知道《封神》真的封神了:读懂了它,你就读懂了人性!探索 prompt 编码范式:如何优雅构建测试代码生成提示词?80% 代码秒生成!AI 神器 Copilot 大升级,百万开发者动嘴编码 5 年内成真面试官:如何设计API返回码(错误码)?VCE 低数—几何与测量例题讲解(Ray老师)读懂《围城》,就读懂了感情的3个真相医生如何“从中国内地直接”申请到美国绿卡:5分钟读懂,拒绝走冤枉路!​Transformer升级之路:RoPE是一种β进制编码曙光VCE | 关于matrices部分题目的解析思路(Ray老师)国家药监局关于适用《Q9(R1):质量风险管理》国际人用药品注册技术协调会指导原则的公告一文读懂 OpenAI 创始人的「世界币」一分钟读懂:量子点是个啥?纳米小颗粒彩虹战队!猪工智能这些年这些人这些事—回国散记之再访上海夏日杂诗6.24“排华法案”百年反思集会Last Encore: The Final Days of an Aging Opera Troupe | 人间 · 英文版80%代码秒生成!AI神器Copilot大升级,百万开发者动嘴编码5年内成真世界文学无法逾越的高峰,读懂他,就读懂了自己工信部印发《关于推进5G轻量化(RedCap)技术演进和应用创新发展的通知》;面向AIGC的RISC-V内核来了|AIoT情报书单 | 读懂这些,就读懂了美国十分钟读懂Diffusion:图解Diffusion扩散模型正圆投资廖茂林:实战派高人气基金经理的投资密码(市场篇)北美十大“最佳”和“最差”机场出炉,佛州西南国际机场(RSW)荣登榜首「专题速递」JPEG AI、端到端图像编码的标准化及产品落地、深度学习Denzel Washington/ FlightESC中国之声丨PCI后STEMI患者随机血糖与白蛋白比值(RAR)与造影后急性肾损伤和临床结果的关系读懂教员,就能读懂中国股市
logo
联系我们隐私协议©2024 redian.news
Redian新闻
Redian.news刊载任何文章,不代表同意其说法或描述,仅为提供更多信息,也不构成任何建议。文章信息的合法性及真实性由其作者负责,与Redian.news及其运营公司无关。欢迎投稿,如发现稿件侵权,或作者不愿在本网发表文章,请版权拥有者通知本网处理。