TensorFlow2 Quick Start: Image Loading and Preprocessing

Author: K同学啊 | Published 2021-01-16 10:18

    Author: 明天依旧可好


    Download the data

    import tensorflow as tf
    
    import pathlib
    data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                             fname='flower_photos', untar=True)
    data_root = pathlib.Path(data_root_orig)
    print(data_root)
    """
    Output:
    C:\Users\Administrator\.keras\datasets\flower_photos
    """
    

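    The downloaded files can be found under C:\Users\Administrator\.keras\datasets\flower_photos. By default, get_file caches downloads in the .keras\datasets folder of the user's home directory.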

    # Inspect the data directory
    for item in data_root.iterdir():
        print(item)
    """
    Output:
    C:\Users\Administrator\.keras\datasets\flower_photos\daisy
    C:\Users\Administrator\.keras\datasets\flower_photos\dandelion
    C:\Users\Administrator\.keras\datasets\flower_photos\LICENSE.txt
    C:\Users\Administrator\.keras\datasets\flower_photos\roses
    C:\Users\Administrator\.keras\datasets\flower_photos\sunflowers
    C:\Users\Administrator\.keras\datasets\flower_photos\tulips
    """
    

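    The flower_photos folder contains five subfolders and one license file. Each subfolder holds the images of one class, so the five folders correspond to five different labels.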
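    Since each subfolder name doubles as a class label, a quick sanity check is to count the images in every subfolder. The sketch below is an illustration rather than part of the original tutorial: the helper name count_images_per_class is hypothetical, and it only counts regular files inside the immediate subdirectories, so LICENSE.txt at the top level is skipped. The demo builds a tiny fake folder tree so it runs offline.

```python
import pathlib
import tempfile


def count_images_per_class(root):
    """Return {subfolder name: number of files in it} for each subfolder of root.

    Hypothetical helper: top-level files such as LICENSE.txt are ignored,
    because only directories are treated as classes.
    """
    root = pathlib.Path(root)
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts


if __name__ == "__main__":
    # Build a small flower_photos-like tree in a temporary directory.
    with tempfile.TemporaryDirectory() as tmp:
        root = pathlib.Path(tmp)
        (root / "LICENSE.txt").write_text("license")
        for name, n in [("daisy", 2), ("roses", 3)]:
            (root / name).mkdir()
            for i in range(n):
                (root / name / f"{i}.jpg").write_bytes(b"")
        print(count_images_per_class(root))  # {'daisy': 2, 'roses': 3}
```

    On the real dataset, count_images_per_class(data_root) should list the five class folders.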

    import random
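    # Collect the paths of all images ('*/*' only matches files inside the class folders, so LICENSE.txt is excluded)
    all_image_paths = list(data_root.glob('*/*'))
    all_image_paths = [str(path) for path in all_image_paths]
    # Shuffle all paths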
    random.shuffle(all_image_paths)
    
    image_count = len(all_image_paths)
    image_count
    """
    Output: 3670
    """
    all_image_paths[:3]
    """
    Output:
    ['C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\daisy\\11870378973_2ec1919f12.jpg',
     'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\roses\\8442304572_2fdc9c7547_n.jpg',
     'C:\\Users\\Administrator\\.keras\\datasets\\flower_photos\\dandelion\\17574213074_f5416afd84.jpg']
    """
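    random.shuffle without a seed produces a different order on every run, which makes results hard to reproduce. A minimal sketch, assuming a repeatable order is wanted: shuffle with a dedicated random.Random(seed) instance and derive each label from the path's parent folder after shuffling, so paths and labels cannot drift apart. The function name shuffled_paths_and_labels and the seed value 123 are illustrative choices.

```python
import pathlib
import random


def shuffled_paths_and_labels(paths, seed=123):
    """Return (paths, labels) shuffled reproducibly; labels are the parent folder names."""
    paths = [str(p) for p in paths]  # copy, so the caller's list is not modified
    rng = random.Random(seed)        # private RNG: leaves the global random state alone
    rng.shuffle(paths)
    labels = [pathlib.Path(p).parent.name for p in paths]
    return paths, labels


if __name__ == "__main__":
    demo = ["flower_photos/daisy/a.jpg", "flower_photos/roses/b.jpg", "flower_photos/tulips/c.jpg"]
    p1, l1 = shuffled_paths_and_labels(demo)
    p2, l2 = shuffled_paths_and_labels(demo)
    print(p1 == p2)  # True: same seed, same order
    print(l1)
```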
    

    Inspect the images

    from PIL import Image
    
    # Open the first 20 images; only these are displayed below
    train_images = []
    for image_path in all_image_paths[:20]:
        train_images.append(Image.open(image_path))
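    Image.open is lazy: it reads only the file header and keeps the file open until the pixel data is needed, so holding thousands of unloaded images in a list can exhaust file handles. A minimal sketch using Pillow (already required by the code above): open each file in a with block, convert it to RGB, and resize it to a fixed size so every image has the same shape. The helper name load_images is hypothetical, and the 192x192 default is chosen only to match the image_size used later.

```python
import os
import tempfile

from PIL import Image


def load_images(paths, size=(192, 192)):
    """Open each image, copy its pixels into memory, and close the file right away."""
    images = []
    for path in paths:
        with Image.open(path) as img:  # the file is closed when the block exits
            # convert() returns a fully loaded copy; resize() gives every image the same size
            images.append(img.convert("RGB").resize(size))
    return images


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        paths = []
        for i, wh in enumerate([(320, 240), (500, 333)]):
            p = os.path.join(tmp, f"{i}.png")
            Image.new("RGB", wh, color=(255, 0, 0)).save(p)
            paths.append(p)
        imgs = load_images(paths)
        print([im.size for im in imgs])  # [(192, 192), (192, 192)]
```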
    

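    Take the labels from the same paths, in the same order as the images (each label is the name of the image's parent folder), and display the first 20 images with their labels.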

    import matplotlib.pyplot as plt
    
    train_labels  = [pathlib.Path(path).parent.name for path in all_image_paths]
    
    plt.figure(figsize=(20,10))
    for i in range(20):
        plt.subplot(5,10,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i])
        plt.xlabel(train_labels[i])
    plt.show()
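    The labels above are class-name strings, while a model needs integer labels. A minimal sketch, assuming the same convention image_dataset_from_directory uses by default (class names in alphabetical order, each mapped to its position in that order): build a name-to-index dictionary and convert the list. encode_labels is a hypothetical helper.

```python
def encode_labels(label_names):
    """Map class-name strings to integers following the alphabetical order of the names."""
    class_names = sorted(set(label_names))
    name_to_index = {name: i for i, name in enumerate(class_names)}
    return class_names, [name_to_index[name] for name in label_names]


if __name__ == "__main__":
    names, ids = encode_labels(["roses", "daisy", "tulips", "daisy"])
    print(names)  # ['daisy', 'roses', 'tulips']
    print(ids)    # [1, 0, 2, 0]
```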
    

    Build a tf.data.Dataset

    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
      data_root,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(192, 192),
      batch_size=20)
    
    class_names = train_ds.class_names
    print("\n",class_names)
    
    train_ds
    """
    Output:
    Found 3670 files belonging to 5 classes.
    Using 2936 files for training.
    
     ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    <BatchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
    """
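    The counts in this output follow from simple arithmetic. To my understanding, Keras computes the number of validation files as int(validation_split * total) and gives the remaining files to training, so 3670 files with validation_split=0.2 become 734 validation and 2936 training files. With batch_size=20, the training set then yields 147 batches, the last holding only 16 images. The sketch below reproduces this arithmetic in plain Python; the function names are hypothetical. The matching validation set would come from calling image_dataset_from_directory again with subset="validation" and the same seed=123, so that the two subsets do not overlap.

```python
import math


def split_counts(total, validation_split):
    """Return (n_train, n_val), truncating the validation count with int()."""
    n_val = int(validation_split * total)
    return total - n_val, n_val


def batch_counts(n, batch_size):
    """Return (number of batches, size of the last batch) when the remainder is kept."""
    n_batches = math.ceil(n / batch_size)
    last = n - (n_batches - 1) * batch_size
    return n_batches, last


if __name__ == "__main__":
    n_train, n_val = split_counts(3670, 0.2)
    print(n_train, n_val)             # 2936 734
    print(batch_counts(n_train, 20))  # (147, 16)
```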
    

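    tf.keras.preprocessing.image_dataset_from_directory() creates a dataset that reads image data from a local directory and infers each image's label from the name of its subfolder; unless class_names is given explicitly, the classes are taken in alphabetical order, which is why class_names prints as shown above. The resulting dataset object can be passed directly to fit() or iterated over in a custom low-level training loop. In newer TensorFlow versions the same function is also available as tf.keras.utils.image_dataset_from_directory.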
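    To illustrate passing a dataset straight to fit(), here is a minimal sketch that trains for one epoch on a batched tf.data.Dataset. It uses random tensors in place of the flower images so it runs offline, and the tiny model (Flatten plus one Dense layer) is a placeholder, not a recommended architecture. The map call rescales pixel values from [0, 255] to [0, 1], a common preprocessing step that is an addition here; with the real data the same map could be applied to train_ds.

```python
import tensorflow as tf

# Stand-in for train_ds: 40 random "images" of shape 192x192x3 with values in [0, 255)
# and integer labels for 5 classes, batched by 20 like the dataset above.
images = tf.random.uniform((40, 192, 192, 3), maxval=255.0)
labels = tf.random.uniform((40,), maxval=5, dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(20)

# Rescale pixel values to [0, 1]; map() applies the function to every (images, labels) batch.
ds = ds.map(lambda x, y: (x / 255.0, y))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(192, 192, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(5),  # one logit per class
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# The dataset yields (images, labels) pairs, so no separate y argument is passed to fit().
history = model.fit(ds, epochs=1, verbose=0)
print(history.history["loss"])
```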

    import matplotlib.pyplot as plt
    
    plt.figure(figsize=(20, 10))
    for images, labels in train_ds.take(1):
        for i in range(20):
            ax = plt.subplot(5, 10, i + 1)
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
    
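    • dataset.take(1): builds a dataset from the first element (the first one, not a random one). Because train_ds was created with batch_size=20, each element is a batch of 20 images, so take(1) here yields the first 20 images. Note that image_dataset_from_directory shuffles by default (shuffle=True), so which images end up in that first batch can change each time the dataset is iterated.
    • dataset.skip(2): builds a dataset that skips the first 2 elements, i.e. here the first 2 batches.

    Check the shapes of one batch: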
    for image_batch, labels_batch in train_ds:
        print(image_batch.shape)
        print(labels_batch.shape)
        break
    """
    Output:
    (20, 192, 192, 3)
    (20,)
    """
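    The take/skip behavior described in the bullets above is easy to confirm on a small, unshuffled dataset. In this sketch tf.data.Dataset.range(10).batch(3) stands in for train_ds: each element is a batch, so take(1) returns the first batch and skip(2) drops the first two batches, not two individual items.

```python
import tensorflow as tf

# Ten integers grouped into batches of 3: [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
ds = tf.data.Dataset.range(10).batch(3)

first = [b.tolist() for b in ds.take(1).as_numpy_iterator()]
rest = [b.tolist() for b in ds.skip(2).as_numpy_iterator()]

print(first)  # [[0, 1, 2]]
print(rest)   # [[6, 7, 8], [9]]
```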
    


        Article link: https://www.haomeiwen.com/subject/thtnaktx.html