美文网首页
【TensorFlow2.0】数据读取与使用方式

【TensorFlow2.0】数据读取与使用方式

作者: 有三AI | 来源:发表于2019-06-13 19:01 被阅读0次

    大家好,这是专栏《TensorFlow2.0》的第三篇文章,讲述如何使用TensorFlow2.0读取和使用自己的数据集。

    如果您正在学习计算机视觉,无论你通过书籍还是视频学习,大部分的教程讲解的都是MNIST等已经为用户打包封装好的数据集,用户只需要load_data即可实现数据导入。但是在我们平时使用时,无论您是做分类还是检测或者分割任务,我们不可能每次都能找到打包好的数据集使用,大多数时候我们使用的都是自己的数据集,也就是我们需要从本地读取文件。因此我们是很有必要学会数据预处理这个本领的。本篇文章,我们就聊聊如何使用TensorFlow2.0对自己的数据集进行处理。

    作者&编辑 | 汤兴旺

    在TensorFlow2.0中,对数据处理的方法有很多种,下面我主要介绍两种我自认为最好用的数据预处理的方法。

    1 使用Keras API对数据进行预处理

    1.1 数据集

    本文用到的数据集是表情分类数据集,数据集有1000张图片,包括500张微笑图片,500张非微笑图片。图片预览如下:

    微笑图片:

    非微笑图片:

    数据集结构组织如下:

    其中800张图片用来训练,200张用来测试,每个类别的样本也是相同的。

    1.2 数据预处理

    我们知道,在将数据输入神经网络之前,需要将数据格式化为经过预处理的浮点数张量。现在我们看看数据预处理的步骤,如下图:

    这个步骤虽然看起来比较复杂,但在TensorFlow2.0的高级API Keras中有个比较好用的图像处理的类ImageDataGenerator,它可以将本地图像文件自动转换为处理好的张量。

    接下来我们通过代码来解释如何利用Keras来对数据预处理,完整代码如下:

    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    train_data_dir = r"D://Learning//tensorflow_2.0//smile//data//train"

    val_data_dir = r"D://Learning//tensorflow_2.0//smile//data//val"

    img_width,img_height = 48,48

    batch_size = 16

    train_datagen = ImageDataGenerator(

           rescale=1./255,

           shear_range=0.2,

           horizontal_flip=True)

    val_datagen = ImageDataGenerator(rescale=1. / 255)

    train_generator = train_datagen.flow_from_directory(

           train_data_dir,

           target_size=(img_width, img_height),

           batch_size=batch_size)

    val_generator = val_datagen.flow_from_directory(

           val_data_dir,

           target_size=(img_width, img_height),

           batch_size=batch_size)

    在上面的代码中,我们首先导入ImageDataGenerator,即下面代码:

    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    ImageDataGenerator是tensorflow.keras.preprocessing.image模块中的图片生成器,同时也可以使用它在batch中对数据进行增强,扩充数据集大小,从而增强模型的泛化能力。

    ImageDataGenerator中有众多的参数,如下:

    tf.keras.preprocessing.image.ImageDataGenerator(

       featurewise_center=False,

       samplewise_center=False,

       featurewise_std_normalization=False,

       samplewise_std_normalization=False,

       zca_whitening=False,

       zca_epsilon=1e-6,

       rotation_range=0.,

       width_shift_range=0.,

       height_shift_range=0.,

       brightness_range,    

       shear_range=0.,

       zoom_range=0.,

       channel_shift_range=0.,

       fill_mode='nearest',

       cval=0.,

       horizontal_flip=False,

       vertical_flip=False,

       rescale=None,

       preprocessing_function=None,

       data_format=K.image_data_format())

    具体含义如下:

    featurewise_center:布尔值,使输入数据集去中心化(均值为0)

    samplewise_center:布尔值,使输入数据的每个样本均值为0。

    featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化。

    samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差。

    zca_whitening:布尔值,对输入数据施加ZCA白化。

    rotation_range:整数,数据增强时图片随机转动的角度。随机选择图片的角度,是一个0~180的度数,取值为0~180。

    width_shift_range:浮点数,图片宽度的某个比例,数据增强时图片随机水平偏移的幅度。

    height_shift_range:浮点数,图片高度的某个比例,数据增强时图片随机竖直偏移的幅度。 

    shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)。是用来进行剪切变换的程度。

    zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range,1+zoom_range]。用来进行随机的放大。

    channel_shift_range:浮点数,随机通道偏移的幅度。

    fill_mode:‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理。

    cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值。

    horizontal_flip:布尔值,进行随机水平翻转。随机的对图片进行水平翻转,这个参数适用于水平翻转不影响图片语义的时候。

    vertical_flip:布尔值,进行随机竖直翻转。

    rescale: 值将在执行其他处理前乘到整个图像上,我们的图像在RGB通道都是0~255的整数,这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。

    preprocessing_function: 将被应用于每个输入的函数。该函数将在任何其他修改之前运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array。

    下面看看我们对数据集增强后的一个效果,由于图片数量太多,我们显示其中9张图片,增强后图片如下:

    大家可以多尝试下每个增强后的效果,增加些感性认识,数据增强和图片显示代码如下,只需要更改ImageDataGenerator中的参数,就能看到结果。

    import matplotlib.pyplot as plt

    from PIL import Image

    from tensorflow.keras.preprocessing import image

    import glob

    datagen = ImageDataGenerator(rotation_range=30,rescale=1./255,

           shear_range=0.2,horizontal_flip=True)

    gen_data=datagen.flow_from_directory(

            r"D://Learning//tensorflow_2.0//smile//datas//mouth//test", 

           batch_size=1,

      shuffle=False,                           save_to_dir=r"D://Learning//tensorflow_2.0//smile//datas//mouth//model",

           save_prefix='gen',

           target_size=(48, 48))

    for i in range(9):

       gen_data.next()

    name_list=glob.glob(r"D://Learning//tensorflow_2.0//smile//datas//mouth//model"+'/*')

    fig = plt.figure()

    for i in range(9):

       img = Image.open(name_list[i])

       sub_img = fig.add_subplot(331 + i)

       sub_img.imshow(img)

    plt.show()

    说完了数据增强,我们再看下ImageGenerator类下的函数flow_from_diectory。从这个函数名,我们也明白其就是从文件夹中读取图像。

    train_generator = train_datagen.flow_from_directory(

           train_data_dir,

           target_size=(img_width, img_height),

           batch_size=batch_size)

    flow_from_diectory中有如下参数:

    directory:目标文件夹路径,对于每一个类,该文件夹都要包含一个子文件夹。

    target_size:整数tuple,默认为(256, 256)。图像将被resize成该尺寸

    color_mode:颜色模式,为"grayscale"和"rgb"之一,默认为"rgb",代表这些图片是否会被转换为单通道或三通道的图片。

    classes:可选参数,为子文件夹的列表,如['smile','neutral'],默认为None。若未提供,则该类别列表将从directory下的子文件夹名称/结构自动推断。每一个子文件夹都会被认为是一个新的类。(类别的顺序将按照字母表顺序映射到标签值)。

    class_mode: "categorical", "binary", "sparse"或None之一。默认为"categorical。该参数决定了返回的标签数组的形式,"categorical"会返回2D的one-hot编码标签,"binary"返回1D的二值标签。"sparse"返回1D的整数标签,如果为None则不返回任何标签,生成器将仅仅生成batch数据。

    batch_size:batch数据的大小,默认32。

    shuffle:是否打乱数据,默认为True。

    seed:可选参数,打乱数据和进行变换时的随机数种子。

    save_to_dir:None或字符串,该参数能让你将数据增强后的图片保存起来,用以可视化。

    save_prefix:字符串,保存数据增强后图片时使用的前缀, 仅当设置了save_to_dir时生效。

    save_format:"png"或"jpeg"之一,指定保存图片的数据格式,默认"jpeg"。

    这些参数中的directory一定要弄清楚,它是指类别文件夹的上一层文件夹,在该数据集中,类别文件夹为smile和neutral,它的上一级文件夹是train。所以director为 r"D://Learning//tensorflow_2.0//smile//data//train"

    另外,class这个参数也要注意,通常我们就采用默认None,directory的子文件夹就是标签。在该分类任务中标签就是smile和neutral。

    以上就是在TensorFlow2.0中利用Keras这个高级API来对分类任务中的数据进行预处理。另外如果您需要完成一个目标检测等任务,则需要自定义一个类来继承ImageDataGeneraton。具体怎么操作,请期待我们的下回关于如何利用TensorFlow2.0处理目标检测任务的分享。

    2 使用Dataset类对数据预处理

    由于该方法在TensorFlow1.x版本中也有,大家可以比较查看2.0相对于1.x版本的改动地方。下面是TensorFlow2.0中使用的完整代码:

    import tensorflow as tf

    #from tensorflow.contrib.data import Dataset

    #from tensorflow.python.framework import dtypes

    #from tensorflow.python.framework.ops import convert_to_tensor

    txtfile=r"D://Learning//tensorflow_2.0//smile//datas//train//train.txt"

    batch_size = 64

    num_classes = 2

    image_size = (48,48)

    class ImageData:

       def read_txt_file(self):

           self.img_paths = []

           self.labels = []

           for line in open(self.txt_file, 'r'):

               items = line.split(' ')

               self.img_paths.append(items[0])

               self.labels.append(int(items[1]))

       def __init__(self, txt_file, batch_size, num_classes,

                    image_size, buffer_scale=100):

           self.image_size = image_size

           self.batch_size = batch_size

           self.txt_file = txt_file  ##txt list file,stored as: imagename id

           self.num_classes = num_classes

           buffer_size = batch_size * buffer_scale

           # 读取图片

           self.read_txt_file()

           self.dataset_size = len(self.labels)

           print("num of train datas=", self.dataset_size)

          # 转换成Tensor

           #self.img_paths=convert_to_tensor(self.img_paths, dtype=dtypes.string)

           #self.labels =convert_to_tensor(self.labels, dtype=dtypes.int32)

           # 转换成Tensor

           self.img_paths = tf.convert_to_tensor(self.img_paths, dtype=tf.string)

           self.labels = tf.convert_to_tensor(self.labels, dtype=tf.int32)

           # 创建数据集

           data = tf.data.Dataset.from_tensor_slices((self.img_paths, self.labels))

           print("data type=", type(data))

           data = data.map(self.parse_function)

           data = data.repeat(1000)

           data = data.shuffle(buffer_size=buffer_size)

           # 设置self data Batch

           self.data = data.batch(batch_size)

           print("self.data type=", type(self.data))

       def augment_dataset(self, image, size):

           distorted_image = tf.image.random_brightness(image,

                                                        max_delta=63)

           distorted_image = tf.image.random_contrast(distorted_image,

                                                      lower=0.2, upper=1.8)

           # Subtract off the mean and divide by the variance of the pixels.

           float_image = tf.image.per_image_standardization(distorted_image)

           return float_image

       def parse_function(self, filename, label):

           label_ = tf.one_hot(label, self.num_classes)

           #img = tf.read_file(filename)

           img = tf.io.read_file(filename)

           img = tf.image.decode_jpeg(img, channels=3)

           img = tf.image.convert_image_dtype(img, dtype=tf.float32)

           #img =tf.random_crop(img,[self.image_size[0],self.image_size[1],3])      

           img=tf.image.random_crop(img, [self.image_size[0], self.image_size[1], 3])

           img = tf.image.random_flip_left_right(img)

           img = self.augment_dataset(img, self.image_size)

           return img, label_

    dataset = ImageData(txtfile, batch_size, num_classes, image_size)

    上图中标红色的地方是tensorFlow2.0版本与1.x版本的区别,红色部分属于1.X版本。主要更改在contrib部分,在tensorFlow2.0中已经删除contrib了,其中有维护价值的模块会被移动到别的地方,剩余的都将被删除,这点大家务必注意。

    如果您对上面代码有任何不明白的地方请移步之前的文章:【tensorflow速成】Tensorflow图像分类从模型自定义到测试

    重要活动,本周有三AI纪念扑克牌发售中,只有不到100套的名额噢,先到先得!

    总结

    本文主要介绍了如何在TensorFlow2.0中对自己的数据进行预处理。主要由两种比较好用的方法,第一种是TensorFlow2.0中特有的,即利用Keras高级API对数据进行预处理,第二种是利用Dataset类来处理数据,它和TensorFlow1.X版本基本一致。

    下期预告:使用TensorFlow构建深度学习网络。

    往期

    有三AI一周年了,说说我们的初衷,生态和愿景

    【TensorFlow2.0】TensorFlow2.0专栏上线,你来吗?

    【TensorFlow2.0】以后我们再也离不开Keras了?

    【TensorFlow2.0】数据读取与使用方式

    【TensorFlow2.0】如何搭建网络模型?

    如果想加入我们,后台留言吧

    转载文章请后台联系

    侵权必究

    技术交流请移步知识星球

    更多请关注知乎专栏《有三AI学院》和公众号《有三AI》

    相关文章

      网友评论

          本文标题:【TensorFlow2.0】数据读取与使用方式

          本文链接:https://www.haomeiwen.com/subject/fntdfctx.html