美文网首页
TensorFlow(二)制作自己的数据集

TensorFlow(二)制作自己的数据集

作者: 续袁 | 来源:发表于2019-05-13 14:15 被阅读0次

1.数据集格式

2. 数据集制作代码

import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt
# Edge length used for the square crop of each image (make_data_set crops to
# IMAGE_SIZE x IMAGE_SIZE x 3; presumably in pixels).
IMAGE_SIZE = 160
# Global constants describing the data set. These names follow the layout of
# the TensorFlow CIFAR-10 example, but the data here is the face data set
# produced by make_data_set, not CIFAR-10.
# NOTE(review): make_data_set assigns labels 0-4 (five folder names), while
# NUM_CLASSES is 2 -- confirm which value the training code expects.
NUM_CLASSES = 2
# Presumably the number of records in the train / eval files -- TODO confirm
# against the files actually generated.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 1018
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 100

# Folder name -> class label stored as the first byte of every record.
_SYNDROME_LABELS = {
    "22q11.2 Deletion syndrome (DiGeorge syndrome and Velocardiofacial syndrome)": 0,
    "Noonan syndrome": 1,
    "Trisomy 21 (Down syndrome)": 2,
    "Turner syndrome": 3,
    "Williams syndrome": 4,
}


def _list_files(folder):
    """Return the sorted names of the regular files directly inside *folder*."""
    return sorted(
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


def _decode_images(image_paths, image_size):
    """Decode JPEG files into flat uint8 rows (R plane, then G, then B).

    The decode/crop graph is built once and a single session is reused for
    every image, instead of adding new ops and opening a session per file.
    Returns an empty list without touching TensorFlow when there is nothing
    to decode.
    """
    if not image_paths:
        return []
    raw_jpeg = tf.placeholder(tf.string, shape=[])
    # channels=3 makes grayscale JPEGs decode to RGB so the crop below is valid.
    decoded = tf.image.decode_jpeg(raw_jpeg, channels=3)
    cropped = tf.random_crop(decoded, [image_size, image_size, 3])
    rows = []
    with tf.Session() as sess:
        for img_path in image_paths:
            print(img_path)
            with tf.gfile.FastGFile(img_path, 'rb') as image_file:
                image_raw = image_file.read()
            pixels = sess.run(cropped, feed_dict={raw_jpeg: image_raw})
            # HWC -> CHW, then flatten: all red values, all green, all blue.
            rows.append(np.transpose(pixels, (2, 0, 1)).reshape(-1))
    return rows


def _write_records(path, labels, rows, record_len):
    """Write one ``label byte + pixel bytes`` record per image to *path*.

    Returns the (num_records, 1 + record_len) uint8 array that was written.
    """
    label_col = np.array(labels, dtype=np.uint8).reshape(len(labels), 1)
    pixel_rows = np.array(rows, dtype=np.uint8).reshape(len(rows), record_len)
    records = np.hstack((label_col, pixel_rows))
    print(records.shape)
    with open(path, 'wb') as f:
        # One bulk write in row-major order (same byte order as iterating .flat).
        f.write(records.tobytes())
    return records


def make_data_set(data_path="E:\\research\\data_more\\frontal_face_160",
                  output_dir="./bin", image_size=160):
    """Convert a folder-per-class JPEG collection into two binary record files.

    Each sub-folder of *data_path* whose name appears in ``_SYNDROME_LABELS``
    is one class. Within a class the files are taken in sorted order; the
    file at 1-based position ``p`` goes to the test set when
    ``p * 5 < number_of_files`` and to the training set otherwise. Every image
    is decoded, cropped to ``image_size x image_size x 3`` and stored as one
    record: one uint8 label byte followed by the red, green and blue planes.

    Args:
        data_path: Root directory containing one sub-folder per class.
        output_dir: Directory receiving ``face_test_160_tf`` and
            ``face_train_160_tf``; created if missing.
        image_size: Edge length of the square crop.

    Entries of *data_path* that are not directories, and folders without a
    known label, are skipped rather than mislabelled. A split with no images
    produces an empty file.
    """
    print(len(_list_files(data_path)))

    test_paths, test_labels = [], []
    train_paths, train_labels = [], []
    for folder_name in sorted(os.listdir(data_path)):
        folder = os.path.join(data_path, folder_name)
        if not os.path.isdir(folder):
            continue
        print(folder_name)
        if folder_name not in _SYNDROME_LABELS:
            print("skipping folder without a known label:", folder_name)
            continue
        label = _SYNDROME_LABELS[folder_name]
        file_names = _list_files(folder)
        image_count = len(file_names)
        print(image_count)
        for position, file_name in enumerate(file_names, start=1):
            img_path = os.path.join(folder, file_name)
            if position * 5 < image_count:
                print("进入测试集:")
                test_paths.append(img_path)
                test_labels.append(label)
            else:
                print("进入训练集:")
                train_paths.append(img_path)
                train_labels.append(label)

    # Decode everything in a single session, then split back into the two sets.
    rows = _decode_images(test_paths + train_paths, image_size)
    test_rows = rows[:len(test_paths)]
    train_rows = rows[len(test_paths):]
    count = len(rows)
    record_len = 3 * image_size * image_size

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Test set.
    test_records = _write_records(
        os.path.join(output_dir, 'face_test_160_tf'),
        test_labels, test_rows, record_len)
    print("count:", count)

    # Training set.
    _write_records(
        os.path.join(output_dir, 'face_train_160_tf'),
        train_labels, train_rows, record_len)
    print("count:", count)

    # Preview the first few bytes of the test file (label, then red values).
    for element in test_records.flat[:9]:
        print(element)

# Script entry point: build the data set only when this file is executed
# directly, not when it is imported.
if "__main__" == __name__:
    make_data_set()

参考资料

[1] TensorFlow 制作自己的TFRecord数据集 读取、显示及代码详解
[2] Numpy学习(2):将cifar10/100数据文件读入到python数据结构(字典)中
[3] 用自己的图片构建cifar10 binary格式的数据十分重要
[4] Numpy学习(4):自己动手制作类似于cifar10这样的图像数据集
[5] 用自己的数据,制作python版本的cifar10数据集
[6] 【Python实现卷积神经网络】:用自己的图片制作cifar-10格式的数据用于测试神经网络+python实现代码
[7] 制作自己的数据集之2 将图片存储为cifar的Python3数据格式 --- png图片存为类cifar10的二进制数据
[8] # CIFAR10/CIFAR100数据集介绍 重要
[9]制作自己的python版本的类CIFAR10数据集
[10] CIFAR-10和python读取
[11] Tensorflow(四)- CNN_CIFAR(一)- cifar10_input 重要
[11] TensorFlow 训练自己的数据集之从 tfrecords 读取数据
[12] tensorflow数据读取之tfrecords
[13] # TensorFlow走过的坑之---数据读取和tf中batch的使用方法

[4] Tensorflow中关于FixedLengthRecordReader()的理解
[5] Tensorflow-tf.FixedLengthRecordReader详解
[6] 『TensorFlow』读书笔记进阶卷积神经网络分类cifar10_下

[7] tensorflow学习——tfreader格式,队列读取数据tf.train.shuffle_batch()

数据集

[1] cifar10

相关文章

网友评论

      本文标题:TensorFlow(二)制做自己的数据集

      本文链接:https://www.haomeiwen.com/subject/vjokuqtx.html