美文网首页
tensorflow加载以目录标识的数据集

tensorflow加载以目录标识的数据集

作者: WYCWGTDDR | 来源:发表于2020-10-13 09:17 被阅读0次

    对于图片分类网络的训练,往往将照片按类别标签存放在相应目录下,在神经网络训练时,可以使用tensorflow提供的flow_from_directory方法加载,但为了提高数据加载的性能,及使用更加强大的图像增强方法,尝试使用tensorflow.data.Dataset进行数据加载和预处理。

    import tensorflow as tf
    import glob
    import io
    
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus[1], 'GPU')# 指定第2块GPU可用  
    tf.config.experimental.set_memory_growth(device=gpus[1], enable=True)# 按需取用显存
    
    train_dir = "/ai/jzclass/train"
    valid_dir = "/ai/jzclass/valid"
    
    #读取文件列表和对应的标签目录列表
    train_image_path=glob.glob(train_dir + '/*/*.jpg')
    train_image_label=[p.split("/")[4] for p in train_image_path]
    
    #通过目录读取类别列表,并转换为字典
    label_names = os.listdir(train_dir)
    label_to_index = dict((name, index) for index, name in enumerate(label_names))
    
    #将以文本标识的列表转换为数字标签(y值)
    all_image_labels = [label_to_index[path] for path in train_image_label]
    
    #数据加载及增强方法
    def load_preprosess_image(path,label):
        image=tf.io.read_file(path)
        image=tf.image.decode_jpeg(image,channels=3) #有坑,可以用opencv代替
        image=tf.image.resize(image,[360,360])
        image=tf.image.random_crop(image,[224,224,3])
        image=tf.image.random_flip_left_right(image)
        image=tf.image.random_flip_up_down(image)
        image=tf.image.random_brightness(image,0.5)
        image=tf.image.random_contrast(image,0,1)
        image=tf.cast(image,tf.float32)
        image=image/255
        label=tf.reshape(label,[-1])
        return image,label
    
    #数据加载
    train_image_ds=tf.data.Dataset.from_tensor_slices((train_image_path,all_image_labels))
    AUTOTUNE=tf.data.experimental.AUTOTUNE
    train_image_ds=train_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)
    train_count=len(train_image_path)
    
    #乱序和预处理
    train_image_ds=train_image_ds.shuffle(train_count).batch(BATCH_SIZE)
    train_image_ds=train_image_ds.prefetch(AUTOTUNE)
    

    后续就可以直接使用model.fit(train_image_ds, epochs = 100)进行调用。

    相关文章

      网友评论

          本文标题:tensorflow加载以目录标识的数据集

          本文链接:https://www.haomeiwen.com/subject/gdwmpktx.html