Computing the Per-Channel Mean and Standard Deviation of an Image Dataset with PyTorch

Author: 深思海数_willschang | Published: 2021-08-16 16:34

    This method comes from Chapter 6 of the book "Deep Learning with PyTorch Step by Step".


    Here is the code.

    In practice, you can optimize it further to suit your own project and keep it as a general-purpose utility.
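
For reference, the statistics that the code computes can be written out as follows. The notation is an assumption introduced here for illustration and is not taken from the book: N is the number of images, H and W are the height and width after resizing, and x_{i,c,h,w} is a pixel value of image i in channel c. Because `torch.std` applies Bessel's correction by default, the per-image standard deviation divides by HW - 1.

```latex
\mu_{i,c} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{i,c,h,w},
\qquad
\sigma_{i,c} = \sqrt{\frac{1}{HW - 1} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{i,c,h,w} - \mu_{i,c} \right)^{2}}

\hat{\mu}_{c} = \frac{1}{N} \sum_{i=1}^{N} \mu_{i,c},
\qquad
\hat{\sigma}_{c} = \frac{1}{N} \sum_{i=1}^{N} \sigma_{i,c}
```

When every image has the same number of pixels, the averaged mean equals the mean over all pixels of the dataset. The averaged standard deviation is an approximation: it generally differs from the standard deviation computed over all pixels at once, although the two are often close.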

    %matplotlib inline
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    from copy import deepcopy
    
    import torch
    import torch.optim as optim
    import torch.nn as nn
    import torch.nn.functional as F
    
    from torch.utils.data import DataLoader, TensorDataset, random_split
    from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, Resize
    from torchvision.datasets import ImageFolder
    
    # Preprocess the images and load them with ImageFolder
    temp_transform = Compose([Resize(28), ToTensor()])
    # Set the root path to wherever your images are stored
    temp_dataset = ImageFolder(root='./data/rps', transform=temp_transform)
    # Build the data loader
    temp_loader = DataLoader(temp_dataset, batch_size=16)
    
    
    # Compute the per-channel mean and standard deviation of the images
    class GetChannelsNormalize():
        def __init__(self):
            pass
        
        @staticmethod
        def loader_apply(loader, func, reduce='sum'):
            results = [func(x, y) for i, (x, y) in enumerate(loader)]
            results = torch.stack(results, axis=0)
    
            if reduce == 'sum':
                results = results.sum(axis=0)
            elif reduce == 'mean':
                results = results.float().mean(axis=0)
    
            return results
        
        @staticmethod
        def statistics_per_channel(images, labels):
            # NCHW
            n_samples, n_channels, n_height, n_width = images.size()
            # Flatten HW into a single dimension
            flatten_per_channel = images.reshape(n_samples, n_channels, -1)
    
            # Computes statistics of each image per channel
            # Average pixel value per channel 
            # (n_samples, n_channels)
            means = flatten_per_channel.mean(axis=2)
            # Standard deviation of pixel values per channel
            # (n_samples, n_channels)
            stds = flatten_per_channel.std(axis=2)
    
            # Adds up statistics of all images in a mini-batch
            # (n_channels,)
            sum_means = means.sum(axis=0)
            sum_stds = stds.sum(axis=0)
            # Makes a tensor of shape (n_channels,)
            # with the number of samples in the mini-batch
            n_samples = torch.tensor([n_samples]*n_channels).float()
    
            # Stack the three tensors on top of one another
            # (3, n_channels)
            return torch.stack([n_samples, sum_means, sum_stds], axis=0)
        
        @staticmethod
        def make_normalizer(loader):
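            # Static methods are not in scope by bare name, so call them through the class
            total_samples, total_means, total_stds = GetChannelsNormalize.loader_apply(
                loader, GetChannelsNormalize.statistics_per_channel
            )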
            norm_mean = total_means / total_samples
            norm_std = total_stds / total_samples
            
            return Normalize(mean=norm_mean, std=norm_std)
        
        
        
    norm_data = GetChannelsNormalize.make_normalizer(temp_loader)
    print(norm_data)
    # Normalize(mean=tensor([0.8502, 0.8215, 0.8116]), std=tensor([0.2089, 0.2512, 0.2659]))
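
To try the same idea without the image folder, the following is a minimal sketch that runs the averaging logic on synthetic data and compares it with statistics computed over all pixels at once. The helper names `batch_statistics` and `channel_mean_std`, the random tensors, and the manual normalization at the end are assumptions made for illustration; they are not part of the book's code.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def batch_statistics(images):
    """Return a (3, n_channels) tensor for one mini-batch:
    sample count, sum of per-image means, sum of per-image stds."""
    n_samples, n_channels = images.shape[:2]
    # Flatten H and W into one dimension: (n_samples, n_channels, H*W)
    flat = images.reshape(n_samples, n_channels, -1)
    counts = torch.full((n_channels,), float(n_samples))
    sum_means = flat.mean(dim=2).sum(dim=0)
    sum_stds = flat.std(dim=2).sum(dim=0)
    return torch.stack([counts, sum_means, sum_stds], dim=0)


def channel_mean_std(loader):
    """Average the per-image channel means and stds over the whole loader."""
    per_batch = [batch_statistics(x) for x, _ in loader]
    total_samples, total_means, total_stds = torch.stack(per_batch, dim=0).sum(dim=0)
    return total_means / total_samples, total_stds / total_samples


# Synthetic stand-in for an image dataset (hypothetical data, NCHW layout)
torch.manual_seed(0)
images = torch.rand(50, 3, 28, 28)
labels = torch.zeros(50, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=16)

mean, std = channel_mean_std(loader)

# Statistics over all pixels of each channel, for comparison
exact_mean = images.mean(dim=(0, 2, 3))
exact_std = images.std(dim=(0, 2, 3))

print("averaged mean:", mean, "exact mean:", exact_mean)
print("averaged std: ", std, "exact std: ", exact_std)

# Per-channel normalization, (x - mean) / std, written out by hand
normalized = (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
print("mean after normalization:", normalized.mean(dim=(0, 2, 3)))
```

The averaged mean should match the exact one up to floating-point error, while the averaged standard deviation is only an approximation of the exact value. If the exact standard deviation matters for your project, one option is to accumulate the per-channel sum and sum of squares of all pixels across batches instead.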
    
    
