记录一下,小白导入cifar10,python3.7,用谷歌的文件导入cifar10老是出错,于是自己写了一个导入文件。
主要是导入老是报错,UnboundLocalError: local variable 'a' referenced before assignment。
希望能帮到同样遇到该问题的小白们。
输入path = ‘你自己的文件路径(为数据的上层文件夹)’
然后load_train(path),load_test(path)即可。
import matplotlib.pyplot as plt
from scipy.misc import imsave
import numpy as np
import os
import pickle
#path = '/Users/wuyanqing/CIFAR_batches'
def unpickle(file):
fo = open(file, 'rb')
dict = pickle.load(fo, encoding='bytes')
fo.close()
return dict
def load_file(file):
Xtr = unpickle(file)
keys = list(Xtr.keys())
#batch_label = Xtr[keys[0]]
labels = Xtr[keys[1]]
labels = np.array(labels)
data = Xtr[keys[2]]
data = data.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
#filenames = Xtr[keys[3]]
#imgs = data.reshape(10000, 3, 32, 32).transpose(0,2,3,1)
return data, labels
def load_train(path):
xs = []
ys = []
for i in range(1,6):
f = os.path.join(path,'data_batch_%d' %i)
X, Y = load_file(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs)
Ytr = np.concatenate(ys)
return Xtr, Ytr
del X, Y
#X_train, Y_train = load_train(path)
def load_test(path):
f = os.path.join(path, 'test_batch')
X, Y = load_file(f)
return X, Y
#X_test, Y_test = load_test(path)
#X_train.shape
#Y_train.shape
#X_test.shape
#Y_test.shape
没有队列读取,数据增强的时候巨慢,还要继续修改。。。。。
网友评论