LoRA is an extremely popular fine-tuning method right now, and the principle behind it is quite simple. Besides relying on libraries such as DeepSpeed or PEFT, writing it by hand yourself is also a good option.
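The core idea, written in the same notation as the code below: the pretrained weight W of a Linear layer is left in place, and a trainable low-rank update is added on the side, so that for an input x the layer computes

h = x @ W.T + (lora_alpha / lora_r) * (x @ B @ A)

where B has shape (in_dim, lora_r) and is initialized to zero, and A has shape (lora_r, out_dim). Because B starts at zero, the model behaves exactly like the pretrained one at the beginning of training, and only the two small matrices A and B need to be updated.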
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertModel
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, lora_r, lora_alpha):
        super().__init__()
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        # B has shape (in_dim, r) and starts at zero, so the LoRA branch contributes nothing before training
        self.B = nn.Parameter(torch.zeros(size=(in_dim, lora_r)), requires_grad=True)
        # A has shape (r, out_dim); any non-zero init works here because B is zero (the LoRA paper uses a Gaussian)
        self.A = nn.Parameter(torch.rand(size=(lora_r, out_dim)), requires_grad=True)

    def forward(self, x):
        # low-rank update x @ B @ A, scaled by alpha / r
        return (self.lora_alpha / self.lora_r) * (x @ self.B @ self.A)
Here we have defined the LoRA layer; next we need to attach it to a Linear layer.
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, lora_r, lora_alpha):
        super().__init__()
        self.linear = linear
        # build a LoRA layer that matches the wrapped Linear's input/output dimensions
        self.lora = LoRALayer(linear.in_features, linear.out_features, lora_r, lora_alpha)

    def forward(self, x):
        # output of the original Linear plus the low-rank LoRA update
        return self.linear(x) + self.lora(x)
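A quick sanity check (not part of the original snippet, just to show why the zero-initialized B matters): a freshly wrapped layer must produce exactly the same output as the plain Linear it wraps, because the LoRA branch is still zero.

base = nn.Linear(768, 768)
wrapped = LinearWithLoRA(base, 8, 16)
x = torch.randn(2, 768)
print(torch.allclose(base(x), wrapped(x)))  # True: the LoRA branch contributes nothing at init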
Finally, we try LoRA fine-tuning on BERT. The point to pay attention to is how to replace BERT's Linear layers with LinearWithLoRA layers:
- 1 First, split the dotted module name yielded by named_modules() and walk down to the Linear layer's parent module with getattr
- 2 Then replace the Linear attribute on that parent with setattr (a small toy example of these two steps follows this list)
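As a toy illustration of these two steps (the toy nested module and the name "block.dense" below are made up for this sketch; the Model class that follows does the same walk over BERT's real module tree):

toy = nn.Sequential()                      # stand-in for a nested module tree
toy.add_module("block", nn.Sequential())
toy.block.add_module("dense", nn.Linear(4, 4))

name = "block.dense"                       # dotted path, in the form yielded by named_modules()
*parent_path, attr_name = name.split(".")
parent = toy
for part in parent_path:                   # step 1: getattr walk down to the parent module
    parent = getattr(parent, part)
setattr(parent, attr_name,                 # step 2: swap the Linear for a LinearWithLoRA
        LinearWithLoRA(getattr(parent, attr_name), 4, 8))
print(toy)                                 # block.dense is now a LinearWithLoRA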
class Model(nn.Module):
def __init__(self):
super().__init__()
self.model = BertModel.from_pretrained(r"F:\Bert", local_files_only=True)
def print_model(self):
print(self.model)
    def set_layer(self, name, module):
        # walk down the dotted name to find the target module's parent
        split_name = name.split(".")
        parent_layer = self.model
        for i in split_name[:-1]:
            parent_layer = getattr(parent_layer, i)
        attr_name = split_name[-1]
        # replace the attribute on the parent with the new module
        setattr(parent_layer, attr_name, module)
    def insert_lora(self, lora_r, lora_alpha):
        # wrap every nn.Linear in the model (attention, feed-forward and pooler) with LinearWithLoRA
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                new_module = LinearWithLoRA(module, lora_r, lora_alpha)
                self.set_layer(name, new_module)

    def merge_lora(self):
        # fold the LoRA update back into the original Linear weights and restore plain Linear modules
        for name, module in self.model.named_modules():
            if not isinstance(module, LinearWithLoRA):
                continue
            linear = module.linear
            # nn.Linear stores its weight as (out_features, in_features), while B @ A is
            # (in_features, out_features), hence the transpose before adding it to the weight
            linear.weight.data += (module.lora.lora_alpha / module.lora.lora_r) * (module.lora.B @ module.lora.A).t()
            self.set_layer(name, linear)
if __name__ == "__main__":
    model = Model()
    print("=================== original model ====================")
    model.print_model()
    model.insert_lora(8, 16)
    print("=================== with lora model ====================")
    model.print_model()
    print("=================== merge lora model ====================")
    model.merge_lora()
    model.print_model()
The insert_lora function replaces every Linear module with a LinearWithLoRA module.
The merge_lora function folds the trained LoRA weights back into the Linear weights and replaces every LinearWithLoRA module with a plain Linear module again.
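During real fine-tuning you would usually also freeze the original weights so that only the LoRA matrices are trained, and after training you can verify that merge_lora leaves the model's outputs unchanged. A minimal sketch of both checks, reusing the Model class above (the random token ids are just a dummy input for illustration):

model = Model()
model.insert_lora(8, 16)

# freeze everything except the LoRA parameters, so only A and B receive gradients
for name, param in model.named_parameters():
    param.requires_grad = "lora" in name

model.eval()                                     # disable dropout so the two forward passes are comparable
dummy_ids = torch.randint(0, 21128, (1, 16))     # random token ids (vocab size taken from the printout below)
with torch.no_grad():
    before = model.model(dummy_ids).last_hidden_state
    model.merge_lora()
    after = model.model(dummy_ids).last_hidden_state
print(torch.allclose(before, after, atol=1e-6))  # True: merging changes the weights, not the outputs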
The architecture of the original BERT model is as follows:
BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(21128, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
The BERT model architecture after the LoRA layers have been inserted:
BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(21128, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): LinearWithLoRA(
(linear): Linear(in_features=768, out_features=768, bias=True)
(lora): LoRALayer()
)
(key): LinearWithLoRA(
(linear): Linear(in_features=768, out_features=768, bias=True)
(lora): LoRALayer()
)
(value): LinearWithLoRA(
(linear): Linear(in_features=768, out_features=768, bias=True)
(lora): LoRALayer()
)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): LinearWithLoRA(
(linear): Linear(in_features=768, out_features=768, bias=True)
(lora): LoRALayer()
)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): LinearWithLoRA(
(linear): Linear(in_features=768, out_features=3072, bias=True)
(lora): LoRALayer()
)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): LinearWithLoRA(
(linear): Linear(in_features=3072, out_features=768, bias=True)
(lora): LoRALayer()
)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): LinearWithLoRA(
(linear): Linear(in_features=768, out_features=768, bias=True)
(lora): LoRALayer()
)
(activation): Tanh()
)
)
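After merge_lora, printing the model one more time gives back the original BertModel structure from the first printout: every LinearWithLoRA is replaced by its inner Linear again, whose weight now already contains the scaled B @ A update, so no extra modules are left at inference time.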