RoPE旋转编码已经成为大模型的基础建设,RoPE区别于传统的绝对位置编码,为q,k向量左乘一个旋转变量,将q,k的绝对位置转换为相对位置,使不同位置的q,k进行计算后拥有相对位置信息。

绝对位置编码:(seq_len, hidden_size)

theta的分子为seq_len,分母由hidden_size控制。

同一个位置的位置编码的角度越来越小。

Self-Attention计算公式:

我们希望计算Slef-Attention时,q,k拥有绝对位置,q^T k的计算可以包含相对位置信息,故我们定义一个可以实现的范式:

我们找到一个符合上述范式的函数g():

fq(xm, m)在2d状态下:

fq(xm, m)在N维度状态下:

fq(xm, m)在N维度状态下简便运算:

故位置i和位置j的Attention,除了本身i,j的向量信息之外,还包含了相对位置信息。

1 LLaMa实现

一般RoPE的实现都分为两步:

  • 1 生成不同seq位置不同维度的theta,(seq_len, hidden_size//2)
  • 2 将对应的x与theta按照RoPE的简便计算方法进行相乘。

LLaMa的第2步利用了虚数的计算方法。

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

2 ChatGLM3 实现

chatglm使纯粹的按照RoPE论文中的计算方式进行计算。

def forward_impl(
        self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
):
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()

    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
    return cache

def forward(self, max_seq_len, offset=0):
    return self.forward_impl(
        max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
    )


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)