美文网首页
PyTrch深度学习简明实战21 - 标注图片的数据预处理

PyTrch深度学习简明实战21 - 标注图片的数据预处理

作者: 薛东弗斯 | 来源:发表于2023-04-08 12:55 被阅读0次
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize = (12,8))
img = Image.open(r'./data//Oxford-IIIT Pets Dataset/dataset/images/Abyssinian_1.jpg')
anno = Image.open(r'./data/Oxford-IIIT Pets Dataset/dataset/annotations/trimaps/Abyssinian_1.png')
plt.subplot(1,2,1)
plt.imshow(np.array(img))
plt.subplot(1,2,2)
plt.imshow(np.array(anno))

image.png
# np.unique(np.array(anno))  # array([1, 2, 3], dtype=uint8)   类别编码从1开始

images = glob.glob(r'./data//Oxford-IIIT Pets Dataset/dataset/images/*.jpg')  # 获取所有图片的路径
annotations = glob.glob(r'./data/Oxford-IIIT Pets Dataset/dataset/annotations/trimaps/*.png')  # 获取所有标注的路径
# len(images)   # 7390
# len(annotations)   # 7390

np.random.seed(2022)
index = np.random.permutation(len(images))
images = np.array(images)[index]
annotations = np.array(annotations)[index]

sep = int(len(images)*0.8)
train_imgs = images[:sep]
train_annos = annotations[:sep]
test_imgs = images[sep:]
test_annos = annotations[sep:]

# 编写tranform,对image数据集进行变换
# 标注图annotations=0,1,2,3..., 不能使用transform。 因为transform会有ToTensor方法,将标注分类归一化。
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor()  # transform 只能应用在原图,不能使用在标注图片上面。
])

class Oxford_dataset(data.Dataset):
    def __init__(self,img_path,anno_path):
        self.imgs = img_path
        self.annos = anno_path
    def __getitem__(self,index):
        img = self.imgs[index]    # 输入数据类型默认为float32
        anno = self.annos[index]
        pil_img = Image.open(img)
        pil_img = pil_img.convert('RGB')  # 假如图像中有黑白图片,转换为RGB格式
        img_tensor = transform(pil_img)
        # 标注图处理. 标注图打开以后,输出已经是类别值0/1/2/3...,
        # 只需要3步; 1-Image.open打开图片   2-直接使用对象.resize方法改变大小, 3- 然后转换为tensor。千万不要用transform
        pil_anno = Image.open(anno)
        pil_anno = pil_anno.resize((256,256))
        anno_tensor = torch.tensor(np.array(pil_anno),dtype = torch.int64)  # 先经过np.array转换为ndarry,再转换为tensor格式
        return img_tensor,anno_tensor-1  # 标注信息从1开始,需要减一
    def __len__(self):
        return len(self.imgs)
    
train_dataset = Oxford_dataset(train_imgs,train_annos)
test_dataset = Oxford_dataset(test_imgs,test_annos)

BATCHSIZE=8

train_dl = data.DataLoader(train_dataset,
                          batch_size = BATCHSIZE,
                          shuffle = True)

test_dl = data.DataLoader(test_dataset,
                         batch_size = BATCHSIZE
                         )

img_batch,anno_batch = next(iter(train_dl))
# img_batch.shape    # torch.Size([8, 3, 256, 256])
# img = img_batch[0].permute(1,2,0).numpy()  # 将channel放到最后,并转换为numpy
# np.unique(anno)   # array([1, 2, 3], dtype=uint8)

后面,就是训练部分代码。

相关文章

网友评论

      本文标题:PyTrch深度学习简明实战21 - 标注图片的数据预处理

      本文链接:https://www.haomeiwen.com/subject/vctvddtx.html