作 者: 月牙眼的楼下小黑
联 系:zlf111@mail.ustc.edu.cn
声 明: 欢迎转载本文中的图片或文字,请说明出处
参考资料:
[1].PyTorch常用工具模块
1 数据处理
import torch
from torch.utils import data
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()
1.1 数据加载
在 Pytorch
中, 数据加载可通过自定义一个继承Dataset
类的数据集对象, 并实现两个方法:
- __ getitem __ : 返回一条数据
- __ len __ : 返回样本数量
class DogCat(data.Dataset):
def __init__(self, root):
imgs = os.listdir(root) # root:图片所在文件夹路径
self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
def __getitem__(self, index):
img_path = self.imgs[index]
if 'dog' in img_path.split('/')[-1]:
label =1
else:
label = 0
pil_img= Image.open(img_path) # 利用 python 图像处理标准库的 open 方法打开图片
array = np.asarray(pil_img) # 将 PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]
#array = np.transpose(array, (2, 0, 1)) # 将 channel last 形式转化成 channel first 形式:[channel, height, width]
data = torch.from_numpy(array) # 将 np.ndarray 转化为 Tensor 形式
return data, label
def __len__(self):
return len(self.imgs)
补充: 对于三维矩阵的转置, 如 a.transpose(2,0,1)
, 意思是原矩阵a
中(aix 0, aix 1, aix 2)
处的值,现在成为了转置后矩阵 (aix 2, aix 0 , aix 1)
处的值。
in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs')
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img) # 若为 channel last 形式的 tensor, 可用 matplotlib 中 imshow() 方法
print(label, img.size(), img.float().mean())
out:
1 torch.Size([500, 282, 3]) 169.23073522458628

in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img) # 若为 channel first 形式的 tensor, 可用 transforms 中 的 ToPILImage() 方法
print(label, img.size(), img.float().mean())
在前面文章中提到过:ToPILImage
可以将
-
shape
为(C,H,W)
的Tensor
-
shape
为(H,W,C)
的numpy.ndarray
转化成PIL.Image
,值不变,方便可视化。注意到 它只能转变channel first
形式的Tensor
。而在上面的__getitem__
中,array = np.asarray(pil_img)
将 PIL.image
转化为 np. ndarray
形式, 默认为 channel last
形式: [height, width, channel]
。 所以如果我们要使用 ToPILImage
方法显示图片,在将 PIL.image 转化为 np. ndarray
形式后,还需要利用转置方法将 channel last
形式改成 channel first
形式:[channel, height, width]
1.2 数据预处理
torchvision.transforms
模块提供了对 PILImage
对象和Tensor
对象的常用操作。
对PILImage
的操作包括:
-
Scale
: 调整图片尺寸,长宽比保持不变 -
CenterCrop
、RandomCrop
、RandomSizeCrop
:裁剪图片 Pad
-
ToTensor
: 将PILImage
对象转化成channel first
的Tensor
并归一至[0,1]
对 Tensor
的操作包括:
-
Normalize
: 标准化,减均值,除以标准差 -
ToPlLImage
:将Tensor
转化为PILImage
对象
in:
trans = transforms.Resize((100,100))
image = Image.open('./dog.1.jpg')
print(image.size)
image = trans(image)
print(image.size)
out:
(327, 499)
(100, 100)
如果要对图片进行多个操作, 可通过Compose方法将这些操作拼接起来。
in:
transform = transforms.Compose([
transforms.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
transforms.CenterCrop(224), # 从图片中间切出 224x224 的图片
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [1, 1, 1])
])
class DogCat(data.Dataset):
def __init__(self, root, transforms = None):
imgs = os.listdir(root) # root:图片所在文件夹路径
self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
self.transforms = transforms
def __getitem__(self, index):
img_path = self.imgs[index]
if 'dog' in img_path.split('/')[-1]:
label =1
else:
label = 0
data = Image.open(img_path) # 利用 python 图像处理标准库的 open 方法打开图片
if self.transforms:
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs', transforms = transform)
img,label= dataset[0]
print(img.size())
show(img)
out:
torch.Size([3, 224, 224])

1.3 ImageFolder
torchvision
预先实现了常用的DataSet
,如CIFAR-10
, 可通过 torchvision.datasets.CIFAR10
来调用。这里介绍一个经常使用的 DataSet
——ImageFolder
. ImageFolder
假设所有文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名, ImageFolder
会根据文件夹名顺序自动生成 label
, 可以通过 class_to_idx
查看 label
和 文件夹名的映射关系。其构造函数如下:
ImageFolder(root, transform = None, target_transform = None, Loader = default_loader)
- root :文件夹路径
- transform: 对
PILImage
的转化操作,transform
的输入是loader
的返回对象 - target_transform: 对
label
的转化 - loader : 读取图片函数,默认为读取
RGB
格式的PILImage
对象
in:
from torchvision.datasets import ImageFolder
dataset = ImageFolder('./myDataSet/dog_cat_data/train')
in:
dataset.class_to_idx
out:
{'cats': 0, 'dogs': 1}
in:
# 此时还没有任何 transform, 返回的是 PILImage 对象
# 第一维指示第几张图片,第二维为 1 返回 label, 为 0 返回 图片数据
print(dataset[0][1])
dataset[0][0]
out:
0

1.4 DataLoader
调用DataSet
中的__getitem__
只返回一个样本,而我们需要batch wise training
,Pytorch
提供了 DataLoader
帮助我们实现这些功能。其构造函数如下:
DataLoader(dataset,
batch_size=1,
shuffle = False,
sample =None,
sampler = None,
num_workers =0,
collate_fn = default_collate,
pin_memory =False,
drop_last = False)
- dataset: 加载的数据集对象(
DataSet
对象) - shuffle: 是否将数据打乱
- sampler: 样本抽样
- num_workers: 使用多进程加载的进程数
- collate_fn: 如何将多个样本数据拼接成一个
patch
- pin_memory: 是否将数据保存在
pin memory
,pin memory
中的数据转到gpu
会快一些 - drop_last: 当
datast
中的数据个数不是batchsize
的整数倍,drop_last
为True
会将多出来不足一个batch
的数据丢弃。
2. torchvision
torchvision
是Pytorch
团队开发的独立于 Pytorch
的视觉工具包,通过pip install torchvision
安装,主要包含三部分:
- models:提供一系列经典已经预训练好的模型,包括
AlexNet
,VGG
,ResNet
,Inception
等 - datasets: 提供常用的数据集对象(
DataSet
对象), 包括MNIST
,CIFAR10/100
,ImageNet
,COCO
等 - transforms: 提供常用数据预处理工作,主要包括对
Tensor
和PILImage
对象的操作。
3. 可视化工具 Visdom
Visdom
可以创造、组织和共享多种数据的可视化,包括数值、图像、文本、视频, 支持 Pytorch
和Numpy
。.
-
Visdom
的安装:pip install visdom
-
Visdom
的启动:python -m visdom.server
, 打开浏览器输入:http://localhost:8097
, 8097 是默认端口号。
Visdom
中的两个重要概念:
- env: 环境。不同用户、不同程序一般使用不同
env
. 不同env
相互独立,互不影响。使用时不指定env
,则默认使用main
- pane: 窗格。 一个
env
中可以有多个不同的pane
, 每个pane
可视化或记录某一信息,可以拖动、缩放、保存或关闭
In:
import visdom
vis = visdom.Visdom(env=u'test1') # 构建一个客户端对象,创建一个名为' test1' 的 env
x = torch.arange(1, 30, 0.01)
y = torch.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'}) # win 是 pane 名字,opts 设置 pane 格式,如 title, xlabel,ylabel

在
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
中, win
参数指定 pane
名字, 如果不指定,visdom
将自动分配一个新的pane
. 如果两次操做指定的win
名字一样,新操作将覆盖当前 pane
的内容。如在 上面的 pane
中画 y = x
函数,原来的 y = sin(x)
将被覆盖。
In:
y = x
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=x'})

如果不想覆盖原图,可以使用
updateTrace
方法,如:
y = x + 1
vis.updateTrace(X=x, Y=y, win='sinx', name='this is a new Trace')

网友评论