0%

Pytorch中的Batchnorm踩坑记录[BUG fix]

在西门子实习的时候,当时突发奇想想设计一个pre-training加post-training的架构,当时是希望在后面接一个projection head,然后前面用预训练好的模型冻结住不改变就可以了。原本以为只需要写一串这样的代码就完全OK:

1
2
for param in model.xxxpart.parameters():
param.require_grad = False

结果训练完后面的projection head,前面的模型输出的重建图片的结果和原来差别巨大,我就感觉并没有冻结住所有的变量,肯定是对前面也进行了训练。查看训练前后模型的参数变化:

1
2
3
4
5
6
7
8
model_dict = model.state_dict()
before_training_params = {k: v.clone() for k, v in ckpt_state_dict.items() if k in model_dict and model_dict[k].shape == v.shape}
After_training_params = {k: v.clone() for k, v in model.state_dict().items() if k in before_training_params and model_dict[k].shape == v.shape}

for k, v in before_training_params.items():
# check if the forzen parameters are changed or not
if not torch.equal(v, After_training_params[k]):
print(f"Parameter {k} is changed after training")

输出的结果如下:

1
2
3
4
5
6
7
Parameter inc.double_conv.1.running_mean is changed after training
Parameter inc.double_conv.1.running_var is changed after training
Parameter inc.double_conv.1.num_batches_tracked is changed after training
Parameter inc.double_conv.4.weight is not changed after training
Parameter inc.double_conv.5.weight is not changed after training
Parameter inc.double_conv.5.bias is not changed after training
...

发现与BatchNorm有关的地方weight和bias被冻结了,running_mean / running_var / num_batches_tracked 这三个变量发生了改变,于是询问GPT老师这三个到底是什么样的含义,为什么已经设定了 require_grad = False 还是会发生变化。

running_meanrunning_var 两个变量记录了在训练过程中输入数据的均值和方差的统计量。在训练模式(model.train())下,BatchNorm 层会根据当前批次的数据更新这些统计量(通过动量平均),即使对应的 weight 和 bias 参数被冻结(requires_grad = False)。

num_batches_tracked 记录了训练期间处理的批次数量,用于计算动量平均。它本身不会直接影响模型的输出,但它的变化表明 BatchNorm 的统计量更新仍在进行。

所以我没有关闭这些统计量的更新,导致了训练前后重建效果差异巨大。为了防止这些统计量的变化,可以采用以下的两种方案:

  • 在模型的 forward 函数里面冻结部分不计算梯度来不更新统计量:
    1
    2
    3
    4
    5
    6
    class myModel(nn.Module):
    def forward(self, x):
    with torch.no_grad():
    x = self.fc1(x)
    ...
    output = self.outoput_mask(x)
  • 手动禁用 BatchNorm 的统计量更新:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def freeze_bn(model):
    for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
    module.trainable = False # 禁用训练模式
    module.eval() # 强制使用预训练统计量
    for param in model.parameters():
    param.requires_grad = False
    for param in model.output_maskes.parameters():
    param.requires_grad = True
    # 在冻结参数后调用
    freeze_bn(model)

到此就可以完全彻底的冻结这些层了,本质上还是因为 BatchNorm 它维护了两类参数/缓冲区:

  • 可训练参数(由优化器更新):
    weight(缩放因子,γ),
    bias(偏移因子,β)
    这些参数可以通过 requires_grad = False 冻结,防止优化器更新。

  • 缓冲区(不由优化器更新,而是由前向传播更新):
    running_mean(移动平均均值)
    running_var(移动平均方差)
    num_batches_tracked(批次计数器)
    这些缓冲区是在训练模式(module.train())下,通过前向传播根据当前批次的数据自动更新的。它们的更新逻辑与 requires_grad 无关,因此即使冻结了所有参数(requires_grad = False),它们仍然会在训练时发生变化。

至此,bug改好了😅