代码如下：
首先导入必需的包
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
###########################################
下载数据，然后用 DataLoader 封装
# Hyperparameters. Only batch_size is used in this section; lr and
# num_epochs are presumably consumed by training code elsewhere in the file.
batch_size = 64
lr = 1e-3
num_epochs = 100


def _load_cifar10(train):
    """Return one CIFAR-10 split (train or test) stored under ./colorful_data.

    Each split gets its own ToTensor transform. download=True lets torchvision
    fetch the files when they are absent; when already present they are only
    verified (the "Files already downloaded and verified" lines in the output).
    """
    return datasets.CIFAR10(
        root='./colorful_data',
        train=train,
        transform=transforms.ToTensor(),
        download=True,
    )


cifar_train = _load_cifar10(train=True)
cifar_test = _load_cifar10(train=False)

# Training batches are shuffled; the test loader keeps the dataset's
# original order (shuffle is left at its default).
train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(cifar_test, batch_size=batch_size)
############################################
显示彩色图片（索引为 5，即训练集中的第六张图片）：plt.imshow(cifar_train.data[5][:, :, :])
# Inspect the raw training array, then display one sample image.
# Both prints show the same array: train_loader wraps cifar_train, so
# train_loader.dataset.data is cifar_train.data.
print(cifar_train.data.shape)
print(train_loader.dataset.data.shape)

sample_idx = 5  # 0-based index, i.e. the sixth training image
# Each entry is a (32, 32, 3) array (see the printed shape); imshow takes
# this height-width-channel layout directly, so no slicing is needed.
sample_img = cifar_train.data[sample_idx]
plt.imshow(sample_img)
# Label the figure with the sample's class name so the image is identifiable.
plt.title(cifar_train.classes[cifar_train.targets[sample_idx]])
plt.show()
############################################
结果如下：
Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(50000, 32, 32, 3)
第六张图片（索引 5）.jpg
网友评论