Sometimes the learning rate schedulers that ship with transformers or torch don't meet your needs, and you have to write your own.
torch.optim.lr_scheduler.LambdaLR builds a scheduler for an optimizer from a user-supplied lr_lambda function.
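The key detail is that lr_lambda returns a multiplicative factor applied to the optimizer's initial lr, not an absolute lr. A minimal sketch of that contract (the halve_every_10 lambda is just an illustration, not part of the code below):

```python
import torch
import torch.nn as nn

param = nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
# The factor returned by the lambda multiplies the base lr (0.1 here),
# so the effective lr is 0.1 at steps 0-9, 0.05 at steps 10-19, and so on.
halve_every_10 = lambda step: 0.5 ** (step // 10)
sched = torch.optim.lr_scheduler.LambdaLR(opt, halve_every_10)
```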
The full code is as follows:
```python
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        return self.fc(x)


def get_cosine_with_min_lr(optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1):
    def lr_lambda(current_step):
        # Warmup phase: the factor rises linearly from 0 to 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Decay phase: cosine curve from 1 toward 0, clipped at min_lr.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(min_lr, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)


total_steps = 10000
warmup_steps = 2000
lr = 1.0      # base lr; the lambda's return value is multiplied by it
min_lr = 0.5  # floor for the factor (an absolute lr floor here, since the base lr is 1.0)

model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = get_cosine_with_min_lr(optimizer, warmup_steps, total_steps, min_lr)

x = []
y = []
for step in range(total_steps):
    x.append(step)
    y.append(optimizer.param_groups[0]["lr"])  # record the lr in effect at this step
    scheduler.step()

plt.figure(figsize=(10, 8))
plt.xlabel("steps")
plt.ylabel("lr")
plt.plot(x, y, color="r", label="cosine")
plt.legend(loc="best")
plt.show()
```
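A design note on the max() clipping: once the cosine term drops below min_lr (around step 6000 with these numbers), the factor is held constant at the floor for the rest of training. If you instead want the lr to glide smoothly down and land exactly at min_lr on the last step, one common alternative is to rescale the cosine rather than clip it. A sketch of that variant (the name get_cosine_with_min_lr_rescaled is mine, not an official API):

```python
import math

import torch


def get_cosine_with_min_lr_rescaled(optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        # Rescale the cosine from [0, 1] to [min_lr, 1] so the factor
        # decays smoothly from 1.0 to exactly min_lr, with no flat tail.
        return min_lr + (1.0 - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
```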
The plotted lr curve looks like this:

- warmup: over the first 2000 steps, the lr rises linearly from 0 to 1.0.
- decay: over the remaining 8000 steps, the lr follows the cosine curve down from 1.0 and hits the min_lr floor of 0.5 around step 6000, after which the max() clipping holds it flat; see the sanity check after this list.
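To confirm those two phases numerically, you can probe the recorded curve at a few key steps. This check reuses the y, warmup_steps, and min_lr variables from the script above:

```python
# Sanity-check the recorded curve (continues the script above).
assert abs(y[0] - 0.0) < 1e-8             # warmup starts at lr = 0
assert abs(y[warmup_steps] - 1.0) < 1e-6  # lr peaks at the base lr once warmup ends
assert abs(y[-1] - min_lr) < 1e-6         # decay bottoms out at the min_lr floor
print("lr at steps 0 / 2000 / 9999:", y[0], y[warmup_steps], y[-1])
```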