MoE模型已经成为大模型不可或缺的一环。

MoE模型最大的重点在于:

  • 1 Router
  • 2 负载均衡

今天我们先只关注Router,关注最简单的MoE模块如何实现。

最简单的MoE模块主要关注Router的计算和实际的MoE计算。

1 Router计算

noise = self.noise_linear(x)

noisy_router = torch.rand_like(noise) + F.softplus(noise)

具体代码如下:

class NoisyTopKRouter(nn.Module):
    def __init__(self, hidden_size, expert_size, topk):
        super(NoisyTopKRouter, self).__init__()
        self.topk = topk
        self.noise = nn.Linear(hidden_size, expert_size)
        self.topkRouter = nn.Linear(hidden_size, expert_size)

    def forward(self, x):
        logits = self.topkRouter(x)
        noise = self.noise(x)
        noise = torch.randn_like(noise) * F.softplus(noise)
        logits = logits + noise

        value, index = logits.topk(self.topk, dim=-1)
        zeros = torch.full_like(logits, float("-inf"))
        logits = zeros.scatter(-1, index, value)
        logits = F.softmax(logits, dim=-1)
        return logits, index

2 MoE计算

理解MoE的计算需要转变思维,需要将思维从token角度转化为专家expert角度,找到专家i需要处理的token。

这也是为什么需要将x的batch_size和seq_len合为batch_size*seq_len的原因。

最重要的思维如下:

weighted_outputs = logits_score * expert_outputs
  • 1 获取专家i处理token的mask
  • 2 根据mask获取专家i处理的x的特征
  • 3 根据mask获取专家i的score
  • 4 score * feature

具体代码如下:

class SparseMoE(nn.Module):
    def __init__(self, hidden_size, expert_size, topk):
        super(SparseMoE, self).__init__()
        self.topk = topk
        self.router = NoisyTopKRouter(hidden_size, expert_size, topk)
        self.experts = nn.ModuleList([Expert(hidden_size) for _ in range(expert_size)])

    def forward(self, x):
        finnal_outputs = torch.zeros_like(x)  # (batch_size, seq_len, hidden_size)
        logits, index = self.router(x)  # (batch_size, seq_len, expert_size) (batch_size, seq_len, topk)
        flat_x = x.view(-1, x.size(-1))  # (batch_size*seq_len, hidden_size)
        flat_logits = logits.view(-1, logits.size(-1))  # (batch_size*seq_len, expert_size)
        for i, expert in enumerate(self.experts):
            # 获取专家i处理的token的位置
            mask = (index == i).any(dim=-1)  # (batch_size, seq_len)
            flat_mask = mask.view(-1)  # (batch_size*seq_len, )
            # 获取专家需要处理的token feature
            expert_inputs = flat_x[flat_mask]  # (token_chosen, hidden_size)
            expert_outputs = expert(expert_inputs)  # (token_chosen, hidden_size)
            # 获取token的score
            logits_score = flat_logits[flat_mask, i].unsqueeze(-1)  # (token_chosen, 1)
            # token_score * token_feature
            weighted_outputs = logits_score * expert_outputs # (token_chosen, hidden_size)
            finnal_outputs[mask] += weighted_outputs
        return finnal_outputs

Related Post