Batch Normalization是目前深度学习中广泛应用的技术,其出现来源于“协变量偏移”。
“协变量偏移”的定义为训练过程中网络参数发生变化导致的网络激活分布发生变化的现象。
由于模型激活分布变化,模型就需要继续训练以适应这种变化。
Batch Normalization应允而生,通过BatchNorm的方法,使输入数据的分布变得更稳定,这样就不需要太多调整模型参数,这样就可以提高训练的效率,加速收敛。
Batch Normalization大概分为两个步骤:
- 1 针对输入变量的每一个特征进行Norm,使其期望为0,方差为1.
- 2 为了使分布靠齐激活值的原始分布,设置了可学习的scale gamma和偏置值beta。
Batch Normalization针对训练阶段和推理阶段:
- 均值和方差的计算:
- 训练阶段:在训练过程中,BatchNorm会使用当前批次(batch)的均值和方差来进行归一化。这意味着每个批次的数据会根据自身的统计量进行标准化。
- 推理阶段:在推理(测试)阶段,BatchNorm使用在训练过程中累积的全局均值和方差。这些全局统计量是通过对多个批次的数据进行移动平均(Moving Average)计算得到的。
- 归一化公式:
- 训练阶段:使用当前批次的均值和方差:
- 推理阶段:使用训练过程中累积的全局均值和方差:
- 参数更新:
- 训练阶段:在训练过程中,BatchNorm层不仅在每个批次上计算均值和方差,还会更新全局均值和方差的移动平均值。
- 推理阶段:在推理过程中,不会再更新这些统计量,只是使用训练阶段已经计算好的全局均值和方差。
- 模式切换:
- 在实现中,框架通常会提供一个模式切换的机制,比如在PyTorch中,使用
model.eval()
将模型切换到推理模式,使用model.train()
将模型切换到训练模式。这些操作会影响BatchNorm层的行为。
- 在实现中,框架通常会提供一个模式切换的机制,比如在PyTorch中,使用