美文网首页
pytorch nn.BatchNorm1d 与手动python

pytorch nn.BatchNorm1d 与手动python

作者: 人生一场梦163 | 来源:发表于2019-12-03 19:55 被阅读0次

由于实验需要,便用pytorch函数手动实现了batchnorm函数,但是最后发现结果不对,最后在Pytorch论坛上找到了相关解决办法!

基础

前期实现

上述博客给出了python实现代码,我将其中的numpy函数改成了pytorch的相关函数:

def fowardbn(x, gam, beta, ):
'''
x:(N,D)维数据
'''
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=0)
    var = x.var(dim=0)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

然后与nn.BatchNorm1d计算的结果比较:

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)

out, cache = fowardbn(input2, model2.weight, model2.bias) # 使用相同的尺度变换量

发现结果x和out的值不一样。
然后就不停的找问题是不是实现方法有差别。
\color{red}{最后}在官方论坛上找到了,有人遇到了相同的问题,官方人员给了答复,还提供了一个官方的实现版本
Pytorch的论坛做的还是挺不错的。

问题

我发现官方实现的代码中

var = input.var([0, 2, 3], unbiased=False)

在求输入的方差时,多了一个参数设置unbiased=False,不懂。
我又查看了一下Pytorch的代码文档:

torch.var(input, unbiased=True) → Tensor

Returns the variance of all elements in the input tensor.
If unbiased is False, then the variance will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.

意思是unbiased = False时,通过无偏估计计算,反之则通过贝塞尔矫正方法计算。可用如下图片总结:

image.png
这是统计方面的知识了,可以参考此博客

最终实现代码

将初始代码中方差计算加上参数unbiased = False,结果正确,完整代码如下

def fowardbn(x, gam, beta, ):
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=0)
    var = x.var(dim=0,unbiased=False)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)
out, cache = fowardbn(input2, model2.weight, model2.bias)

Reference

Batch Normalization 学习笔记
Batch Normalization梯度反向传播推导
PyTorch论坛问题
官方人员给的batchnorm2d的手动实现代码
方差的贝塞尔校正

相关文章

网友评论

      本文标题:pytorch nn.BatchNorm1d 与手动python

      本文链接:https://www.haomeiwen.com/subject/cwfhgctx.html