MoE模型已经成为大模型不可或缺的一环。
MoE模型最大的重点在于:
- 1 Router
- 2 负载均衡
今天我们先只关注Router,关注最简单的MoE模块如何实现。
最简单的MoE模块主要关注Router的计算和实际的MoE计算。
1 Router计算
noise = self.noise_linear(x)
noisy_router = torch.rand_like(noise) + F.softplus(noise)
具体代码如下:
class NoisyTopKRouter(nn.Module):
def __init__(self, hidden_size, expert_size, topk):
super(NoisyTopKRouter, self).__init__()
self.topk = topk
self.noise = nn.Linear(hidden_size, expert_size)
self.topkRouter = nn.Linear(hidden_size, expert_size)
def forward(self, x):
logits = self.topkRouter(x)
noise = self.noise(x)
noise = torch.randn_like(noise) * F.softplus(noise)
logits = logits + noise
value, index = logits.topk(self.topk, dim=-1)
zeros = torch.full_like(logits, float("-inf"))
logits = zeros.scatter(-1, index, value)
logits = F.softmax(logits, dim=-1)
return logits, index
2 MoE计算
理解MoE的计算需要转变思维,需要将思维从token角度转化为专家expert角度,找到专家i需要处理的token。
这也是为什么需要将x的batch_size和seq_len合为batch_size*seq_len的原因。
最重要的思维如下:
weighted_outputs = logits_score * expert_outputs
- 1 获取专家i处理token的mask
- 2 根据mask获取专家i处理的x的特征
- 3 根据mask获取专家i的score
- 4 score * feature
具体代码如下:
class SparseMoE(nn.Module):
def __init__(self, hidden_size, expert_size, topk):
super(SparseMoE, self).__init__()
self.topk = topk
self.router = NoisyTopKRouter(hidden_size, expert_size, topk)
self.experts = nn.ModuleList([Expert(hidden_size) for _ in range(expert_size)])
def forward(self, x):
finnal_outputs = torch.zeros_like(x) # (batch_size, seq_len, hidden_size)
logits, index = self.router(x) # (batch_size, seq_len, expert_size) (batch_size, seq_len, topk)
flat_x = x.view(-1, x.size(-1)) # (batch_size*seq_len, hidden_size)
flat_logits = logits.view(-1, logits.size(-1)) # (batch_size*seq_len, expert_size)
for i, expert in enumerate(self.experts):
# 获取专家i处理的token的位置
mask = (index == i).any(dim=-1) # (batch_size, seq_len)
flat_mask = mask.view(-1) # (batch_size*seq_len, )
# 获取专家需要处理的token feature
expert_inputs = flat_x[flat_mask] # (token_chosen, hidden_size)
expert_outputs = expert(expert_inputs) # (token_chosen, hidden_size)
# 获取token的score
logits_score = flat_logits[flat_mask, i].unsqueeze(-1) # (token_chosen, 1)
# token_score * token_feature
weighted_outputs = logits_score * expert_outputs # (token_chosen, hidden_size)
finnal_outputs[mask] += weighted_outputs
return finnal_outputs