Computing Normalize Values in PyTorch

Author: zeolite | Published 2021-06-15 21:48

Using MNIST as an example, compute the mean and standard deviation that transforms.Normalize expects.

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

batch_size = 1

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_data = datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform
)

# Sample order does not affect the statistics, so shuffling is unnecessary.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

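# Accumulate the pixel sum and squared-pixel sum over the whole training set.
# Averaging per-image standard deviations does not give the dataset standard
# deviation, so the statistics are computed over all pixels instead.
pixel_sum = torch.zeros((), dtype=torch.float64)
pixel_sq_sum = torch.zeros((), dtype=torch.float64)
pixel_count = 0
for image, _ in train_loader:
    image = image.to(torch.float64)  # float64 limits accumulation error
    pixel_sum += image.sum()
    pixel_sq_sum += (image ** 2).sum()
    pixel_count += image.numel()

mean = pixel_sum / pixel_count
# Population standard deviation: sqrt(E[x^2] - E[x]^2)
std = torch.sqrt(pixel_sq_sum / pixel_count - mean ** 2)
print(mean, std)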
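
As a rough illustration of why the statistics should be taken over all pixels rather than averaged per image, the sketch below uses a small synthetic tensor in place of MNIST, so it runs without any download. The tensor `images`, its per-image brightness `scale`, and the variable names are assumptions made up for this example and do not come from the original post. It compares the dataset-wide standard deviation with the average of per-image standard deviations, then applies the same arithmetic that Normalize performs per channel, (x - mean) / std.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for a dataset: 64 single-channel 28x28 "images"
# with values in [0, 1] and a different brightness scale per image.
scale = torch.linspace(0.1, 1.0, 64).view(64, 1, 1, 1)
images = torch.rand(64, 1, 28, 28) * scale

# Dataset-wide statistics computed over every pixel.
mean = images.mean()
std = images.std(unbiased=False)

# Average of the per-image standard deviations, for comparison only.
per_image_std_mean = images.flatten(1).std(dim=1, unbiased=False).mean()

# The same per-channel arithmetic that Normalize applies: (x - mean) / std.
normalized = (images - mean) / std

print(f"dataset-wide std:          {std.item():.4f}")
print(f"mean of per-image stds:    {per_image_std_mean.item():.4f}")
print(f"normalized mean and std:   {normalized.mean().item():.4f}, "
      f"{normalized.std(unbiased=False).item():.4f}")
```

With images of equal size, the average of per-image standard deviations cannot exceed the dataset-wide value, because it ignores how much the image means differ from one another. Only the dataset-wide pair brings the normalized data to zero mean and unit standard deviation.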
