Sometimes the learning rate schedulers that ship with transformers or PyTorch don't fit your needs, and you have to write a scheduler by hand.

torch.optim.lr_scheduler.LambdaLR builds a scheduler from an optimizer and a function lr_lambda: at every step, each parameter group's learning rate is set to its initial lr multiplied by lr_lambda(step).
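In other words, lr_lambda returns a multiplier, not an absolute learning rate. A minimal sketch of that behavior (the 1/(step+1) schedule is arbitrary, chosen just for illustration):

import torch
import torch.nn as nn

model = nn.Linear(100, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
# The value returned by the lambda multiplies the initial lr (1.0 here).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0 / (step + 1))
for step in range(3):
    print(step, optimizer.param_groups[0]["lr"])  # 1.0, 0.5, 0.333...
    scheduler.step()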

The full code:

import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        return self.fc(x)


def get_cosine_with_min_lr(optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1):
    # LambdaLR multiplies each group's initial lr by the value returned here,
    # so min_lr is a floor on that multiplier; it matches the absolute minimum
    # lr only when the base lr is 1.0, as in the example below.
    def lr_lambda(current_step):
        # Linear warmup from 0 up to the base lr.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay from 1.0 towards 0.0, floored at min_lr.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(min_lr, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)


total_steps = 10000
warmup_steps = 2000
lr = 1  # base lr; since lr = 1, the multiplier floor equals the absolute lr
min_lr = 0.5  # floor on the lr multiplier returned by lr_lambda
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = get_cosine_with_min_lr(optimizer, warmup_steps, total_steps, min_lr)


x = []
y = []
for step in range(total_steps):
    x.append(step)
    y.append(optimizer.param_groups[0]["lr"])  # lr in effect at this step
    scheduler.step()  # advance the scheduler to the next step
plt.figure(figsize=(10, 8))
plt.xlabel("steps")
plt.ylabel("lr")
plt.plot(x, y, color="r", label="cosine")
plt.legend(loc="best")
plt.show()

The plotted lr curve looks like this:

  • warmup: over the first 2000 steps, the lr rises linearly from 0 to 1.0.
  • decay: over the remaining 8000 steps, the lr follows a cosine curve down from 1.0 to 0.5.
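As written, min_lr is a floor on the multiplier, which only equals the absolute minimum lr because the base lr is 1. If you train with a realistic base lr (say 1e-4) and want min_lr to be an absolute value, one variant (a sketch; get_cosine_with_abs_min_lr is a name introduced here, and it assumes every param group shares the base lr stored in optimizer.defaults) converts the floor into a ratio first:

import math

import torch


def get_cosine_with_abs_min_lr(optimizer, num_warmup_steps, num_training_steps, min_lr, last_epoch=-1):
    # Convert the absolute floor into a multiplier floor.
    # Assumes all param groups use the optimizer's default base lr.
    min_lr_ratio = min_lr / optimizer.defaults["lr"]

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

With lr=1e-4 and min_lr=1e-5, the schedule warms up to 1e-4 and then decays along the cosine down to 1e-5.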