在mnist手写字符识别的pytorch代码中,回有如下代码:
train_transforms = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
其中transforms.Normalize((0.1307,), (0.3081,))
表示进行标准化,0.1307和0.3081分别是mnist数据集的均值和标准差。
网友评论