先贴上PyTorch官网上的关于BatchNorm的公式:
这个BatchNorm到底怎么freeze?这个函数究竟是如何成为广大网友心中的大坑的?看了好几天源码和相关的博客,我似乎有点明白了。 本文主要内容是_BatchNorm
相关的源码简介。同样基于PyTorch 1.1.0。
参数 1 2 3 4 5 6 def __init__ (self, num_features, eps=1e-5 , momentum=0.1 , affine=True, track_running_stats=True )
这是_BatchNorm
的初始化的参数,无论是1d2d3d的BatchNorm都是继承自这个类,所以只需要看这个就可以了。Doc
这里边很重要的两个参数就是affine
和track_running_stats
了。 跟这两个相关的代码如下:
1 2 3 4 5 6 7 8 9 10 11 if self.affine: self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else : ... if self.track_running_stats: self.register_buffer('running_mean' , torch.zeros(num_features)) self.register_buffer('running_var' , torch.ones(num_features)) self.register_buffer('num_batches_tracked' , torch.tensor(0 , dtype=torch.long)) else : ...
可以看到的是,affine
是和weight和bias相关的,也就是公式中的$\gamma$和$\beta$。 而track_running_stats
是和运行时的均值和方差有关的。
affine 这个参数和公式中的$\gamma$和$\beta$相关,是学习到的变量,也就是说,是通过反向传播 学习到的。这里有一个小例子看一下:
1 2 3 4 5 6 7 8 9 10 11 norm = nn.BatchNorm1d(5 ) print_norm(norm) a = torch.randn(2 , 5 ) b = norm(a).sum() b.backward() print_norm(norm) optimizer = torch.optim.SGD(norm.parameters(), lr=1e5 ) optimizer.step() print_norm(norm)
因为weight的梯度较小,所以我把lr设置的比较大。下边是结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 norm weight Parameter containing: tensor([0.7289, 0.4699, 0.3586, 0.6225, 0.5753], requires_grad=True) grad None norm bias Parameter containing: tensor([0., 0., 0., 0., 0.], requires_grad=True) grad None ******************** norm weight Parameter containing: tensor([0.7289, 0.4699, 0.3586, 0.6225, 0.5753], requires_grad=True) grad tensor([ 4.7068e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00, -9.3044e-08]) norm bias Parameter containing: tensor([0., 0., 0., 0., 0.], requires_grad=True) grad tensor([2., 2., 2., 2., 2.]) ******************** norm weight Parameter containing: tensor([0.6819, 0.4699, 0.3586, 0.6225, 0.5846], requires_grad=True) grad tensor([ 4.7068e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00, -9.3044e-08]) norm bias Parameter containing: tensor([-200000., -200000., -200000., -200000., -200000.], requires_grad=True) grad tensor([2., 2., 2., 2., 2.])
可以看到,在step后,weight和bias被更新了上去。因为这两个参数是学习得到的,所以freeze时就显得很简单了,使用通常的方法就可以:
1 2 for para in model.parameters(): para.requires_grad = False
然后在SGD的时候还得过滤一下:
1 2 optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01 , momentum=0.9 , weight_decay=1e-4 )
这些方法百度里都可以找到。
在这里得到一个小结论,也就是requires_grad
、affine
、weight
、bias
这几个概念是一组相关的概念。
track_running_stats 这个参数在这里是比较迷 的,因为这个参数是运行时的统计信息,不是反向传播学到的。 这个参数通常和traing是联系在一起的。通常情况下,大家说的model.eval()
会不使用BatchNorm和Dropout,这是怎么回事?看一下源码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 @weak_script_method def forward (self, input ): self._check_input_dim(input) if self.momentum is None : exponential_average_factor = 0.0 else : exponential_average_factor = self.momentum if self.training and self.track_running_stats: if self.num_batches_tracked is not None : self.num_batches_tracked += 1 if self.momentum is None : exponential_average_factor = 1.0 / float(self.num_batches_tracked) else : exponential_average_factor = self.momentum return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training or not self.track_running_stats, exponential_average_factor, self.eps)
这里是完整的forward代码,我没有做截取。截止到PyTorch 1.3.1这个函数还是这个样子。可以看到,这两个参数相关的地方就是self.training and self.track_running_stats
和self.training or not self.track_running_stats
两处。 先看第一处,它是对更新时的参数的系数做了一个计算。
第二处可能也是一个比较重要的,在传递到下一个函数进行处理时,它竟然直接把两个参数合并了?难怪这个函数成为广大网友心中的大坑 。 文章[1]对这两个参数做了一些探究,写的不错,他其中引用了一个[2]中的内容,我也直接引用一下:
training=True
, track_running_stats=True
, 这是常用的training时期待的行为,running_mean
和running_var
会跟踪不同batch数据的mean和variance,但是仍然是用每个batch的mean和variance做normalization。
training=True
, track_running_stats=Fals
e, 这时候running_mean
和running_var
不跟踪跨batch数据的statistics了,但仍然用每个batch的mean和variance做normalization。
training=False
, track_running_stats=True
, 这是我们期待的test时候的行为,即使用training阶段估计的running_mean
和running_var
.
training=False
, track_running_stats=False
,同2(!!!).
很明显的是,或操作的结果必定是3个true一个false,对应到这个例子里,也就是只有(3)传入到F.batch_nrom
的参数才是false !另外的三个都是True啊有没有!尤其是(4)竟然变成了True?? 因为这个F里边的函数继续调用我就找不到在哪了,所以没法直接看源码了,只能分别对上述四个情况做几个小实验,测试一下。
为了控制结果一致,设置一个固定的weight值和a
1 2 3 4 5 weight = torch.tensor([0.8 , 0.6 , 0.5 , 0.4 ], requires_grad=True ) norm = nn.BatchNorm1d(4 , ...) norm.weight = Parameter(weight) a = torch.Tensor([[1 , 9 , 7 , 3 ], [2 , 8 , 6 , 4 ]])
condition 1 两个都为True时,结果似乎是最简单的。
1 2 3 4 5 norm = nn.BatchNorm1d(4 ) print_norm_mean_var(norm) norm(a) print_norm_mean_var(norm)
结果显而易见,它学到了当前这个batch的均值和方差,这也是最基本的情况:
1 2 3 4 5 6 norm mean tensor([0., 0., 0., 0.]) norm var tensor([1., 1., 1., 1.]) ******************** norm mean tensor([0.1500, 0.8500, 0.6500, 0.3500]) norm var tensor([0.9500, 0.9500, 0.9500, 0.9500]) ********************
condition 2 把上述代码的第一行改成norm = nn.BatchNorm1d(4, track_running_stats=False)
时,输出的全部都是None,因为初始化的时候这部分值就已经设置成了None。即便初始化以后再执行norm.track_running_stats = True
,结果也还是None,显然当这两个参数为None时后续也不会有什么操作会更改mean和var。
但是!
如果程序运行起来再改成False,那么结果会怎么样?
1 2 3 4 5 6 norm = nn.BatchNorm1d(4 , track_running_stats=True ) print_norm_mean_var(norm) norm.track_running_stats = False norm(a) print_norm_mean_var(norm)
结果如下:
1 2 3 4 5 6 norm mean tensor([0., 0., 0., 0.]) norm var tensor([1., 1., 1., 1.]) ******************** norm mean tensor([0.1500, 0.8500, 0.6500, 0.3500]) norm var tensor([0.9500, 0.9500, 0.9500, 0.9500]) ********************
它还是学到了!因为传入到F.batch_nrom
的参数肯定是True,看起来,唯一能决定它学不学的,应该就是mean和var是不是None 。
condition 3 到这个情况时,问题就复杂多了!
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 print_norm_mean_var(norm) b = norm(a) print('b =' ,b) print_norm_mean_var(norm) norm.eval() b = norm(a) print('b =' ,b) print_norm_mean_var(norm) norm.train() b = norm(a) print('b =' ,b) print_norm_mean_var(norm)
看一下结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 norm mean tensor([0., 0., 0., 0.]) norm var tensor([1., 1., 1., 1.]) ******************** b = tensor([[-0.8000, 0.6000, 0.5000, -0.4000], [ 0.8000, -0.6000, -0.5000, 0.4000]], grad_fn=<NativeBatchNormBackward>) norm mean tensor([0.1500, 0.8500, 0.6500, 0.3500]) norm var tensor([0.9500, 0.9500, 0.9500, 0.9500]) ******************** b = tensor([[0.6977, 5.0170, 3.2575, 1.0875], [1.5184, 4.4014, 2.7445, 1.4979]], grad_fn=<NativeBatchNormBackward>) norm mean tensor([0.1500, 0.8500, 0.6500, 0.3500]) norm var tensor([0.9500, 0.9500, 0.9500, 0.9500]) ******************** b = tensor([[-0.8000, 0.6000, 0.5000, -0.4000], [ 0.8000, -0.6000, -0.5000, 0.4000]], grad_fn=<NativeBatchNormBackward>) norm mean tensor([0.2850, 1.6150, 1.2350, 0.6650]) norm var tensor([0.9050, 0.9050, 0.9050, 0.9050]) ********************
第一次输出的时候是默认的。 第二次输出的是condition 1的情况,即学到了当前这个batch的mean和var,b也被batchnorm了。 第三次的结果是traning=False
时,很明显,这一次它没有继续学习,因为这是condition 3的情况,也就是self.training or not self.track_running_stats == False
。b被batchnorm后的结果发生了变化,是因为这一次用的是上一次学到的mean和var,而不是当前batch的,所以结果发生了变化。 第四次的结果是最让我意外的,因为开启了train,所以它肯定会继续学习,又回到condition 1的情况,但是,b的值似乎和第二次输出又变的一样了?
在我看来,batchnrom在train模式下总是使用当前的batch的mean和var 进行batchnorm,然后学习 当前batch的分布,用于更新running_mean
和running_var
。当切换到eval模式的时候,即使用学习到的running_mean
和running_var
进行batchnorm。所以我们的batchsize应该大一些,且每次取到的batch尽量随机,这样才能不断地学习到整个数据集的分布。
condition 4 如果初始化就把track_running_stats = False
,那么无论是train还是eval,running_mean
和running_var
都是None,且每次只是用当前batch的mean和var 。这里也就是[2]中所说的为什么这个同condition 2。
但是如果开始时true,改成False呢? 因为代码和condition 3完全相同,只是添加了一行norm.track_running_stats = False
,看一下结果的神奇之处:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 norm mean tensor([0., 0., 0., 0.]) norm var tensor([1., 1., 1., 1.]) ******************** b = tensor([[-0.8000, 0.6000, 0.5000, -0.4000], [ 0.8000, -0.6000, -0.5000, 0.4000]], grad_fn=<NativeBatchNormBackward>) norm mean tensor([0.1500, 0.8500, 0.6500, 0.3500]) norm var tensor([0.9500, 0.9500, 0.9500, 0.9500]) ******************** b = tensor([[-0.8000, 0.6000, 0.5000, -0.4000], [ 0.8000, -0.6000, -0.5000, 0.4000]], grad_fn=<NativeBatchNormBackward>) norm mean tensor([0.2850, 1.6150, 1.2350, 0.6650]) norm var tensor([0.9050, 0.9050, 0.9050, 0.9050]) ******************** b = tensor([[-0.8000, 0.6000, 0.5000, -0.4000], [ 0.8000, -0.6000, -0.5000, 0.4000]], grad_fn=<NativeBatchNormBackward>) norm mean tensor([0.4065, 2.3035, 1.7615, 0.9485]) norm var tensor([0.8645, 0.8645, 0.8645, 0.8645]) ********************
首先是,它还是学习了mean和var。这和condition 2结尾得到的结论完全相同。另外,在第二次输出的时候,它依然使用的当前样本的方差,而且继续进行了学习,因为这个的第二次输出和condition 3的第三次输出完全相同 。这很明显了,因为self.training or not self.track_running_stats == True
。
freeze 有了上边的介绍,相信应该已经基本清楚了,最好清楚代码处于什么位置、要什么功能 ,否则最好不要改track_running_stats
这个属性。
在freeze的时候我看到了两种不同的说法,有一种是直接设momentum = 0
或为None
,这个方法看上去使得它不会学习新的分布了,但是似乎存在一点小问题,就是在fine tune的时候,他依然使用的是当前batch的mean和var,而到了测试集的时候,它却使用的是预训练数据集的mean和var,显得有一些诡异。
另外一个方法就是使用eval固定,这个方法可能看起来是最稳妥的,即是condition 3,直接使用从预训练的数据集上训练好的mean和var,[3]已经给出了比较详细的答案,在这里就不多做赘述。我想说的是,如果需要固定的部分没有dropout的话,似乎不需要使用apply去遍历,因为本身在执行eval的时候就会遍历这个结点的所有children,这和apply几乎是相似的。
当然,上述两种方法对与一个模型究竟有什么影响,甚至是否需要单独freeze batchnorm层,都应该具体问题具体分析,还是那句话,应该要清楚需要什么,才能选择适合的方法。 我也是初学者,如果有什么问题或者错误的地方,欢迎联系我~
reference [1] Pytorch的BatchNorm层使用中容易出现的问题 [2] BatchNorm2d增加的参数track_running_stats如何理解? - 李韶华的回答 - 知乎 [3] Pytorch中的Batch Normalization layer踩坑