LoRA is one of the most popular fine-tuning methods right now, and its underlying idea is very simple. Besides relying on libraries such as DeepSpeed or PEFT, writing it by hand is also a perfectly good option.
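The core of LoRA is to keep the pretrained weights frozen and learn a low-rank update on top of them, so a Linear layer's output becomes

output = linear(x) + (lora_alpha / lora_r) * (x @ B @ A)

where B and A have rank lora_r and B is zero-initialized, so the extra branch contributes nothing at the start of training. A straightforward implementation: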

import torch
import torch.nn as nn
from transformers import BertModel

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, lora_r, lora_alpha):
        super().__init__()
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        # B projects the input down to rank lora_r and is zero-initialized,
        # so the LoRA branch outputs zero at the start of training
        self.B = nn.Parameter(torch.zeros(size=(in_dim, lora_r)))
        # A projects back up to out_dim; a simple random init is used here
        # (the LoRA paper uses a Gaussian init for the non-zero matrix)
        self.A = nn.Parameter(torch.rand(size=(lora_r, out_dim)))

    def forward(self, x):
        # scale the low-rank update by lora_alpha / lora_r
        return (self.lora_alpha / self.lora_r) * (x @ self.B @ self.A)

This defines the LoRA layer itself; next we need to wire it into a Linear layer.

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, lora_r, lora_alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, lora_r, lora_alpha)

    def forward(self, x):
        return self.linear(x) + self.lora(x)
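
As a quick sanity check (toy shapes, chosen only for illustration), wrapping a plain nn.Linear should leave its output unchanged at initialization, because B starts at zero:

torch.manual_seed(0)
base = nn.Linear(16, 32)
wrapped = LinearWithLoRA(base, lora_r=8, lora_alpha=16)
x = torch.randn(4, 16)
print(torch.allclose(base(x), wrapped(x)))  # True: the LoRA branch is still zero
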
Finally, we try LoRA fine-tuning on Bert. The key point is how to replace Bert's Linear layers with LinearWithLoRA layers:
  • 1 First walk the dotted module name with getattr to reach the Linear layer's parent module
  • 2 Then use setattr on that parent to swap in the new module
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained(r"F:\Bert", local_files_only=True)

    def print_model(self):
        print(self.model)

    def set_layer(self, name, module):
        # walk the dotted module name to find the target layer's parent
        split_name = name.split(".")
        parent = self.model
        for part in split_name[:-1]:
            parent = getattr(parent, part)
        # replace the target layer on its parent
        setattr(parent, split_name[-1], module)

    def insert_lora(self, lora_r, lora_alpha):
        # wrap every nn.Linear in the model with a LinearWithLoRA module
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                new_module = LinearWithLoRA(module, lora_r, lora_alpha)
                self.set_layer(name, new_module)

    def merge_lora(self):
        # fold each LoRA update back into its original Linear weight
        for name, module in self.model.named_modules():
            if not isinstance(module, LinearWithLoRA):
                continue
            linear = module.linear
            # nn.Linear stores weight as (out_features, in_features), while
            # B @ A is (in_features, out_features), hence the transpose
            linear.weight.data += (module.lora.lora_alpha / module.lora.lora_r) * (module.lora.B @ module.lora.A).t()
            self.set_layer(name, linear)


if __name__ == "__main__":
    model = Model()
    print("=================== origin model ====================")
    model.print_model()
    model.insert_lora(8, 16)
    print("=================== with lora model ====================")
    model.print_model()
    print("=================== merge lora model ====================")
    model.merge_lora()
    model.print_model()

The insert_lora function replaces every Linear module with a LinearWithLoRA module.

The merge_lora function folds the LoRA weights into the original weights and replaces every LinearWithLoRA module with a plain Linear module.
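
Note that insert_lora by itself does not freeze anything; for actual LoRA fine-tuning the pretrained weights should stay frozen and only the A and B matrices should receive gradients. A minimal sketch (the helper name freeze_for_lora is ours, not part of the code above):

def freeze_for_lora(model):
    # freeze every parameter, then re-enable gradients only for the LoRA matrices
    for param in model.parameters():
        param.requires_grad = False
    for module in model.modules():
        if isinstance(module, LoRALayer):
            module.A.requires_grad = True
            module.B.requires_grad = True

model = Model()
model.insert_lora(8, 16)
freeze_for_lora(model)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total}")
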

The architecture of the original Bert model is as follows:
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
The Bert model architecture after inserting the LoRA layers:
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): LinearWithLoRA(
              (linear): Linear(in_features=768, out_features=768, bias=True)
              (lora): LoRALayer()
            )
            (key): LinearWithLoRA(
              (linear): Linear(in_features=768, out_features=768, bias=True)
              (lora): LoRALayer()
            )
            (value): LinearWithLoRA(
              (linear): Linear(in_features=768, out_features=768, bias=True)
              (lora): LoRALayer()
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): LinearWithLoRA(
              (linear): Linear(in_features=768, out_features=768, bias=True)
              (lora): LoRALayer()
            )
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): LinearWithLoRA(
            (linear): Linear(in_features=768, out_features=3072, bias=True)
            (lora): LoRALayer()
          )
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): LinearWithLoRA(
            (linear): Linear(in_features=3072, out_features=768, bias=True)
            (lora): LoRALayer()
          )
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): LinearWithLoRA(
      (linear): Linear(in_features=768, out_features=768, bias=True)
      (lora): LoRALayer()
    )
    (activation): Tanh()
  )
)
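
After merge_lora the printed architecture is identical to the original Bert again, since every LinearWithLoRA has been swapped back for a merged nn.Linear. A quick numerical check (random input_ids and small shapes assumed purely for illustration) confirms that merging does not change the forward pass:

torch.manual_seed(0)
model = Model()
model.insert_lora(8, 16)
model.model.eval()  # disable dropout so the two forward passes are comparable
# give B non-trivial values so the merge actually changes the weights
for module in model.model.modules():
    if isinstance(module, LoRALayer):
        nn.init.normal_(module.B, std=0.02)
input_ids = torch.randint(0, 21128, (1, 16))  # vocab size taken from the printout above
with torch.no_grad():
    out_lora = model.model(input_ids).last_hidden_state
    model.merge_lora()
    out_merged = model.model(input_ids).last_hidden_state
print(torch.allclose(out_lora, out_merged, atol=1e-4))  # expected: True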