# 以 MNIST 为例，计算 transforms.Normalize 所需的均值和标准差
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def compute_mean_std(dataset, batch_size=1000):
    """Return the per-channel mean and standard deviation of *dataset*.

    The statistics are taken over every pixel of every sample, which is what
    ``transforms.Normalize(mean, std)`` expects. Averaging per-image standard
    deviations does NOT give the dataset standard deviation: it ignores the
    spread between image means and, because of the square root, is biased low.

    Args:
        dataset: A map-style dataset yielding ``(image, label)`` pairs where
            ``image`` is a tensor whose first axis is the channel axis, e.g.
            ``(C, H, W)`` as produced by ``transforms.ToTensor()``.
        batch_size: Samples loaded per step. Affects speed and memory only
            (the result is identical up to floating-point rounding).

    Returns:
        ``(mean, std)``: two float64 tensors of shape ``(C,)``. ``std`` is the
        population standard deviation (divides by the pixel count N).

    Raises:
        ValueError: If the dataset contains no pixels.
    """
    # shuffle=False: order is irrelevant for a sum, and a RandomSampler would
    # reject an empty dataset before we can report a clear error.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    total = None
    total_sq = None
    count = 0
    for images, _ in loader:
        # Batches are (B, C, ...): move channels first and flatten the rest so
        # each row holds every value of one channel in this batch.
        # float64 keeps sum / sum-of-squares accurate over millions of pixels.
        flat = images.to(torch.float64).transpose(0, 1).reshape(images.size(1), -1)
        batch_sum = flat.sum(dim=1)
        batch_sq_sum = (flat * flat).sum(dim=1)
        if total is None:
            total, total_sq = batch_sum, batch_sq_sum
        else:
            total += batch_sum
            total_sq += batch_sq_sum
        count += flat.size(1)
    if count == 0:
        raise ValueError("dataset is empty; cannot compute mean/std")
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding.
    var = (total_sq / count - mean * mean).clamp(min=0.0)
    return mean, var.sqrt()


def main():
    """Download the MNIST training split to '.' and print its mean and std."""
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_data = datasets.MNIST(
        root='.',
        train=True,
        download=True,
        transform=transform,
    )
    mean, std = compute_mean_std(train_data)
    print(mean, std)


if __name__ == "__main__":
    main()
# 网友评论