美文网首页
tensorflow导入cifar10

tensorflow导入cifar10

作者: sphinx_catie | 来源:发表于2018-12-07 17:08 被阅读0次

记录一下,小白导入cifar10,python3.7,用谷歌的文件导入cifar10老是出错,于是自己写了一个导入文件。

主要是导入老是报错,UnboundLocalError: local variable 'a' referenced before assignment。

希望能帮到同样遇到该问题的小白们。

输入path = ‘你自己的文件路径(为数据的上层文件夹)’
然后load_train(path),load_test(path)即可。

import matplotlib.pyplot as plt
from scipy.misc import imsave
import numpy as np
import os
import pickle

#path = '/Users/wuyanqing/CIFAR_batches'

def unpickle(file):
    fo = open(file, 'rb')
    dict = pickle.load(fo, encoding='bytes')
    fo.close()
    return dict

def load_file(file):
    Xtr = unpickle(file)
    keys = list(Xtr.keys())
    #batch_label = Xtr[keys[0]]
    labels = Xtr[keys[1]]
    labels = np.array(labels)
    data = Xtr[keys[2]]
    data = data.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    #filenames = Xtr[keys[3]]
    #imgs = data.reshape(10000, 3, 32, 32).transpose(0,2,3,1)
    return data, labels


def load_train(path):
    xs = []
    ys = []

    for i in range(1,6):
        f = os.path.join(path,'data_batch_%d' %i)
        X, Y = load_file(f)
        xs.append(X)
        ys.append(Y)
        Xtr = np.concatenate(xs)
        Ytr = np.concatenate(ys)
    return Xtr, Ytr
    del X, Y
#X_train, Y_train = load_train(path)
    
def load_test(path):
    f = os.path.join(path, 'test_batch')
    X, Y = load_file(f)
    return X, Y

#X_test, Y_test = load_test(path)

#X_train.shape
#Y_train.shape
#X_test.shape
#Y_test.shape

没有队列读取,数据增强的时候巨慢,还要继续修改。。。。。

相关文章

网友评论

      本文标题:tensorflow导入cifar10

      本文链接:https://www.haomeiwen.com/subject/zifxhqtx.html