在西门子实习的时候,当时突发奇想想设计一个pre-training加post-training的架构,当时是希望在后面接一个projection head,然后前面用预训练好的模型冻结住不改变就可以了。原本以为只需要写一串这样的代码就完全OK:
1 | for param in model.xxxpart.parameters(): |
结果训练完后面的projection head,前面的模型输出的重建图片的结果和原来差别巨大,我就感觉并没有冻结住所有的变量,肯定是对前面也进行了训练。查看训练前后模型的参数变化:
1 | model_dict = model.state_dict() |
输出的结果如下:1
2
3
4
5
6
7Parameter inc.double_conv.1.running_mean is changed after training
Parameter inc.double_conv.1.running_var is changed after training
Parameter inc.double_conv.1.num_batches_tracked is changed after training
Parameter inc.double_conv.4.weight is not changed after training
Parameter inc.double_conv.5.weight is not changed after training
Parameter inc.double_conv.5.bias is not changed after training
...
发现与BatchNorm
有关的地方weight和bias被冻结了,running_mean
/ running_var
/ num_batches_tracked
这三个变量发生了改变,于是询问GPT老师这三个到底是什么样的含义,为什么已经设定了 require_grad = False
还是会发生变化。
running_mean
与 running_var
两个变量记录了在训练过程中输入数据的均值和方差的统计量。在训练模式(model.train()
)下,BatchNorm 层会根据当前批次的数据更新这些统计量(通过动量平均),即使对应的 weight 和 bias 参数被冻结(requires_grad = False
)。
num_batches_tracked
记录了训练期间处理的批次数量,用于计算动量平均。它本身不会直接影响模型的输出,但它的变化表明 BatchNorm 的统计量更新仍在进行。
所以我没有关闭这些统计量的更新,导致了训练前后重建效果差异巨大。为了防止这些统计量的变化,可以采用以下的两种方案:
- 在模型的
forward
函数里面冻结部分不计算梯度来不更新统计量:1
2
3
4
5
6class myModel(nn.Module):
def forward(self, x):
with torch.no_grad():
x = self.fc1(x)
...
output = self.outoput_mask(x) - 手动禁用
BatchNorm
的统计量更新:1
2
3
4
5
6
7
8
9
10
11def freeze_bn(model):
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
module.trainable = False # 禁用训练模式
module.eval() # 强制使用预训练统计量
for param in model.parameters():
param.requires_grad = False
for param in model.output_maskes.parameters():
param.requires_grad = True
# 在冻结参数后调用
freeze_bn(model)
到此就可以完全彻底的冻结这些层了,本质上还是因为 BatchNorm
它维护了两类参数/缓冲区:
可训练参数(由优化器更新):
weight(缩放因子,γ),
bias(偏移因子,β)
这些参数可以通过requires_grad = False
冻结,防止优化器更新。缓冲区(不由优化器更新,而是由前向传播更新):
running_mean(移动平均均值)
,running_var(移动平均方差)
,num_batches_tracked(批次计数器)
,
这些缓冲区是在训练模式(module.train()
)下,通过前向传播根据当前批次的数据自动更新的。它们的更新逻辑与requires_grad
无关,因此即使冻结了所有参数(requires_grad = False
),它们仍然会在训练时发生变化。
至此,bug改好了😅