import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import os
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Preview one sample: the photo (left panel) next to its trimap annotation
# (right panel).
plt.figure(figsize=(12, 8))
img = Image.open(r'./data//Oxford-IIIT Pets Dataset/dataset/images/Abyssinian_1.jpg')
anno = Image.open(r'./data/Oxford-IIIT Pets Dataset/dataset/annotations/trimaps/Abyssinian_1.png')
for panel, picture in enumerate((img, anno), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(np.array(picture))
![Abyssinian_1: photo (left) and its trimap annotation (right)](https://img.haomeiwen.com/i3968643/617bfe48dc2947ed.png)
# Trimap pixels hold class ids 1, 2, 3
# (np.unique(np.array(anno)) -> array([1, 2, 3], dtype=uint8)), i.e. they start at 1.


def pair_paths(img_paths, anno_paths):
    """Return (img_paths, anno_paths) ordered so that entry i of each is the same sample.

    glob.glob yields files in arbitrary (filesystem-dependent) order, so two
    independent globs cannot be zipped directly. Both lists are sorted by file
    stem (name without directory or extension) and then checked pairwise,
    e.g. ``Abyssinian_1.jpg`` must line up with ``Abyssinian_1.png``.

    Returns:
        Two new lists: the sorted image paths and the sorted annotation paths.

    Raises:
        ValueError: if the lists differ in length or a pair's stems differ.
    """
    def stem(path):
        return os.path.splitext(os.path.basename(path))[0]

    img_sorted = sorted(img_paths, key=stem)
    anno_sorted = sorted(anno_paths, key=stem)
    if len(img_sorted) != len(anno_sorted):
        raise ValueError(
            f'{len(img_sorted)} images but {len(anno_sorted)} annotations'
        )
    for img_file, anno_file in zip(img_sorted, anno_sorted):
        if stem(img_file) != stem(anno_file):
            raise ValueError(
                f'image {img_file!r} does not match annotation {anno_file!r}'
            )
    return img_sorted, anno_sorted


# Collect all image / trimap paths, aligned pairwise by file stem.
images, annotations = pair_paths(
    glob.glob(r'./data//Oxford-IIIT Pets Dataset/dataset/images/*.jpg'),
    glob.glob(r'./data/Oxford-IIIT Pets Dataset/dataset/annotations/trimaps/*.png'),
)
# len(images) == len(annotations) == 7390 for the full dataset.

# Shuffle both lists with the same fixed-seed permutation (keeps pairs aligned,
# makes the split reproducible), then use the first 80% for training.
np.random.seed(2022)
index = np.random.permutation(len(images))
images = np.array(images)[index]
annotations = np.array(annotations)[index]
sep = int(len(images) * 0.8)
train_imgs = images[:sep]
train_annos = annotations[:sep]
test_imgs = images[sep:]
test_annos = annotations[sep:]
# Preprocessing for the input photos only: resize to 256x256, then ToTensor
# (PIL image -> float tensor, uint8 values scaled into [0, 1]).
# Do NOT apply it to the trimap annotations: their pixels are integer class
# ids (0/1/2/3...), and ToTensor would rescale them into fractions of 1.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)
class Oxford_dataset(data.Dataset):
    """Oxford-IIIT Pet segmentation dataset yielding (image, trimap) tensor pairs.

    ``img_path[i]`` and ``anno_path[i]`` must refer to the same pet. Each item is:

    * ``img_tensor``: the photo converted to RGB and passed through
      ``img_transform`` (default: the module-level ``transform``).
    * ``anno_tensor``: the trimap resized to ``anno_size`` with
      nearest-neighbour sampling, as an int64 tensor of the file's class ids
      minus one (trimap ids 1/2/3 become 0/1/2).

    The trimap is deliberately not passed through ``img_transform``, because
    ``ToTensor`` would rescale its integer class ids into floats.
    """

    def __init__(self, img_path, anno_path, img_transform=None, anno_size=(256, 256)):
        """
        Args:
            img_path: indexable sequence of image file paths.
            anno_path: indexable sequence of trimap file paths, aligned with img_path.
            img_transform: callable applied to the RGB PIL image. None (default)
                uses the module-level ``transform``, looked up on each access.
            anno_size: (width, height) given to PIL ``Image.resize`` for the
                trimap. Keep it equal to the image size produced by
                img_transform so that pixels line up.

        Raises:
            ValueError: if img_path and anno_path differ in length.
        """
        if len(img_path) != len(anno_path):
            raise ValueError(
                f'{len(img_path)} images but {len(anno_path)} annotations'
            )
        self.imgs = img_path
        self.annos = anno_path
        self.img_transform = img_transform
        self.anno_size = anno_size

    def __getitem__(self, index):
        """Load, resize and tensorize the image/trimap pair at ``index``."""
        img_tf = transform if self.img_transform is None else self.img_transform
        # `with` releases the file handle immediately. Image.open is lazy and
        # would otherwise keep the file open until garbage collection.
        with Image.open(self.imgs[index]) as pil_img:
            # convert('RGB') loads the pixels and turns grayscale or palette
            # photos into 3-channel RGB.
            img_tensor = img_tf(pil_img.convert('RGB'))
        with Image.open(self.annos[index]) as pil_anno:
            # NEAREST keeps the labels intact. Pillow's default filter for
            # modes other than '1'/'P' is bicubic, which blends neighbouring
            # labels into wrong or out-of-range class ids along region borders.
            resized = pil_anno.resize(self.anno_size, resample=Image.NEAREST)
        # np.array yields the raw class-id pixels; convert them to int64.
        anno_tensor = torch.tensor(np.array(resized), dtype=torch.int64)
        # Trimap class ids start at 1; shift them to 0-based ids.
        return img_tensor, anno_tensor - 1

    def __len__(self):
        """Number of (image, trimap) pairs."""
        return len(self.imgs)
# Wrap the two splits in datasets and batch them; only training is shuffled.
BATCHSIZE = 8
train_dataset = Oxford_dataset(train_imgs, train_annos)
test_dataset = Oxford_dataset(test_imgs, test_annos)
train_dl = data.DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True)
test_dl = data.DataLoader(test_dataset, batch_size=BATCHSIZE)

# Sanity check: pull one training batch.
# img_batch.shape -> torch.Size([8, 3, 256, 256]).
# anno_batch holds the 0-based class ids (the dataset subtracts 1 from the
# trimap's 1/2/3), so np.unique(anno_batch.numpy()) should give [0, 1, 2].
# To view a photo: img_batch[0].permute(1, 2, 0).numpy()  (channels last).
img_batch, anno_batch = next(iter(train_dl))
后面就是训练部分的代码。
网友评论