Computing the Per-Channel Mean and Standard Deviation of an Image Dataset with PyTorch

Author: 深思海数_willschang | Published 2021-08-16 16:34

This method comes from Chapter 6 of the book "Deep Learning with PyTorch Step by Step".
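For reference, the formulas below show what the code computes for each channel c. This is my own summary of the code, not notation taken from the book. N is the number of images, P = H×W is the number of pixels per image, and x_{i,c,p} is a pixel value after ToTensor(), so it lies in [0, 1]. The per-image standard deviation uses 1/(P−1) because `torch.std` is unbiased by default.

```latex
\begin{aligned}
m_{i,c} &= \frac{1}{P}\sum_{p=1}^{P} x_{i,c,p},
&\quad
s_{i,c} &= \sqrt{\frac{1}{P-1}\sum_{p=1}^{P}\bigl(x_{i,c,p}-m_{i,c}\bigr)^{2}},\\
\mu_c &= \frac{1}{N}\sum_{i=1}^{N} m_{i,c},
&\quad
\sigma_c &= \frac{1}{N}\sum_{i=1}^{N} s_{i,c}.
\end{aligned}
```

Every image has the same P pixels, so μ_c also equals the mean over all pixels in the dataset. `Normalize(mean=μ, std=σ)` then maps each pixel to (x − μ_c)/σ_c.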


Here is the code.

In practice, you can refine it to suit your own project and keep it as a general-purpose utility function.

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torchvision.datasets import ImageFolder

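# Preprocess the images and load them with ImageFolder.
# Resize(28) rescales the shorter side to 28 pixels; ToTensor() converts to [0, 1] floats.
temp_transform = Compose([Resize(28), ToTensor()])
# Set the image directory to match your own setup
temp_dataset = ImageFolder(root='./data/rps', transform=temp_transform)
# Build the data loader
temp_loader = DataLoader(temp_dataset, batch_size=16)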


# Compute the mean and standard deviation of each image channel
class GetChannelsNormalize:
    @staticmethod
    def loader_apply(loader, func, reduce='sum'):
        # Apply func to every mini-batch and stack the results along a new first dimension
        results = [func(x, y) for x, y in loader]
        results = torch.stack(results, dim=0)

        if reduce == 'sum':
            results = results.sum(dim=0)
        elif reduce == 'mean':
            results = results.float().mean(dim=0)

        return results

    @staticmethod
    def statistics_per_channel(images, labels):
        # NCHW
        n_samples, n_channels, n_height, n_width = images.size()
        # Flatten HW into a single dimension
        flatten_per_channel = images.reshape(n_samples, n_channels, -1)

        # Computes statistics of each image per channel
        # Average pixel value per channel
        # (n_samples, n_channels)
        means = flatten_per_channel.mean(dim=2)
        # Standard deviation of pixel values per channel (unbiased, the torch default)
        # (n_samples, n_channels)
        stds = flatten_per_channel.std(dim=2)

        # Adds up statistics of all images in a mini-batch
        # (n_channels,)
        sum_means = means.sum(dim=0)
        sum_stds = stds.sum(dim=0)
        # Makes a tensor of shape (n_channels,)
        # with the number of samples in the mini-batch
        n_samples = torch.tensor([n_samples] * n_channels).float()

        # Stack the three tensors on top of one another
        # (3, n_channels)
        return torch.stack([n_samples, sum_means, sum_stds], dim=0)

    @classmethod
    def make_normalizer(cls, loader):
        # Sum (sample count, sum of means, sum of stds) over all mini-batches, then average
        total_samples, total_means, total_stds = cls.loader_apply(loader, cls.statistics_per_channel)
        norm_mean = total_means / total_samples
        norm_std = total_stds / total_samples

        return Normalize(mean=norm_mean, std=norm_std)

norm_data = GetChannelsNormalize.make_normalizer(temp_loader)
print(norm_data)
# Normalize(mean=tensor([0.8502, 0.8215, 0.8116]), std=tensor([0.2089, 0.2512, 0.2659]))
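One caveat: the std returned above is the average of per-image standard deviations. That average ignores brightness differences between images, so it usually comes out smaller than the standard deviation computed over every pixel in the dataset.

The sketch below is not from the book. It compares the two on hypothetical synthetic tensors and assumes all images share one size. `per_image_average_stats` and `pooled_pixel_stats` are illustrative names. The pooled version accumulates a per-channel running sum and sum of squares in float64, so it needs a single pass and constant memory.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def per_image_average_stats(loader):
    """Book-style estimate: average of per-image channel means and (unbiased) stds."""
    n_total, sum_means, sum_stds = 0, 0.0, 0.0
    for images, _ in loader:
        n, c = images.shape[:2]
        flat = images.reshape(n, c, -1)                      # (N, C, H*W)
        sum_means = sum_means + flat.mean(dim=2).sum(dim=0)  # (C,)
        sum_stds = sum_stds + flat.std(dim=2).sum(dim=0)     # (C,)
        n_total += n
    return sum_means / n_total, sum_stds / n_total


def pooled_pixel_stats(loader):
    """Mean and population std over all pixels of each channel, via running sums."""
    count, total, total_sq = 0, 0.0, 0.0
    for images, _ in loader:
        images = images.double()                       # float64 limits round-off in the sums
        c = images.shape[1]
        flat = images.transpose(0, 1).reshape(c, -1)   # (C, N*H*W)
        total = total + flat.sum(dim=1)
        total_sq = total_sq + (flat ** 2).sum(dim=1)
        count += flat.shape[1]
    mean = total / count
    var = (total_sq / count - mean ** 2).clamp_min(0)  # E[x^2] - E[x]^2
    return mean.float(), var.sqrt().float()


# Hypothetical synthetic data: 64 RGB 28x28 images with values in [0, 1).
# A random per-image, per-channel offset adds between-image variation.
torch.manual_seed(0)
images = torch.rand(64, 3, 28, 28) * 0.5 + torch.rand(64, 3, 1, 1) * 0.5
labels = torch.zeros(64, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=16)

book_mean, book_std = per_image_average_stats(loader)
exact_mean, exact_std = pooled_pixel_stats(loader)
print("mean:", book_mean, exact_mean)  # the two means agree (equal-size images)
print("std: ", book_std, exact_std)    # the pooled std is larger here
```

Whichever estimate you use, the resulting Normalize goes after ToTensor() in the transform pipeline, for example Compose([Resize(28), ToTensor(), norm_data]).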

Article link: https://www.haomeiwen.com/subject/jfambltx.html