美文网首页
tensorflow 读取自己的数据集用于训练

tensorflow 读取自己的数据集用于训练

作者: feden | 来源:发表于2019-03-13 22:50 被阅读0次

    ```

    #!/usr/bin/python

    # -*- coding: utf-8 -*-

    from __future__ import division

    import os

    import numpy as np

    import tensorflow as tf

    from utils import *

    image_path = "./ck/cohn-kanade-images/"

    label_path = "./ck/Emotion_labels/"

    Imglist, ImgLabellist, Labellist = GetImg_Label_list(image_path, label_path)  #得到图片和标签相应的列表

    def parse_data(filename,label):

        '''

        导入数据,进行预处理,输出两张图像,

        分别是输入图像和标签

        Args:

            filaneme, 图片的路径

        Returns:

            处理后图像,标签

        '''

        # 读取图像

        image = tf.read_file(filename)

        # 解码图片

        image = tf.image.decode_image(image)

        # 数据预处理,或者数据增强,这一步根据需要自由发挥

        image = tf.image.crop_to_bounding_box(image, 0, 0, 64, 64)

        # 数据增强,随机水平翻转图像

        image = tf.image.random_flip_left_right(image)

        # 图像归一化

        image = tf.cast(image, tf.float32) / 255.0

        return image, label

    def train_generator(batchsize, shuffle=True):

        with tf.Session() as sess:

            # 创建数据库

            train_dataset = tf.data.Dataset().from_tensor_slices((Imglist, Labellist))

            # 预处理数据

            train_dataset = train_dataset.map(parse_data)

            # 设置 batch size

            train_dataset = train_dataset.batch(batchsize)

            # 无限重复数据

            train_dataset = train_dataset.repeat()

            # 洗牌,打乱

            if shuffle:

                train_dataset = train_dataset.shuffle(buffer_size=4)

            # 创建迭代器

            train_iterator = train_dataset.make_initializable_iterator()

            sess.run(train_iterator.initializer)

            train_batch = train_iterator.get_next()

            # 开始生成数据

            while True:

                try:

                    x_batch, y_batch = sess.run(train_batch)

                    yield (x_batch, y_batch)

                except:

                    # 如果没有  train_dataset = train_dataset.repeat()

                    # 数据遍历完就到end了,就会抛出异常

                    train_iterator = train_dataset.make_initializable_iterator()

                    sess.run(train_iterator.initializer)

                    train_batch = train_iterator.get_next()

                    x_batch, y_batch = sess.run(train_batch)

                    yield (x_batch, y_batch)

    #检查结果是否正确

    x_batch = train_generator(16)

    for i in range(5):

        x,y = next(x_batch)

        print(x,y)    # 结果为一个batch 为16的数据包括图片和对应的标签

    ```

    相关文章

      网友评论

          本文标题:tensorflow 读取自己的数据集用于训练

          本文链接:https://www.haomeiwen.com/subject/uagxmqtx.html