美文网首页
TF2.0:训练集、测试集的地址、标签获取完整流程!

TF2.0:训练集、测试集的地址、标签获取完整流程!

作者: 胜负55开 | 来源:发表于2020-05-21 16:30 被阅读0次

    TF2.0获得训练集、测试集的所有文件地址和对应的标签,下面是完整的一条龙操作,简单快捷好理解,获得的结果可直接作为tf.data的输入数据。

    1. 首先:训练集、测试集放到两个不同的文件夹里:
    图1:训练集、测试集放两个文件夹内

    每一个文件夹下又有airplane、lake两个文件夹:


    图2:每个文件夹内又有2个子文件夹,分别放对应的图
    2. 获得所有训练、测试数据的路径(字符串):
    import glob
    
     # train文件夹内所有文件夹都要,所有文件夹内的所有.jpg文件都要!
    train_data_path = glob.glob( 'E:/data/train/*/*.jpg' ) 
    # test文件夹内所有文件夹都要,所有文件夹内的所有.jpg文件都要!
    test_data_path = glob.glob( 'E:/data/test/*/*.jpg' )    
    
    # 查看一下:数据量、类型
    len(train_data_path), type(train_data_path[0])
    (1400, str)
    
    # 再查看一下:内容是什么
    train_data_path[0]
    'E:/data/train\\airplane\\airplane_001.jpg'
    
    3. 将数据全部打乱:
    import random 
    
    random.shuffle( train_data_path )
    random.shuffle( test_data_path )
    
    4. 获得训练集、测试集一共有哪些标签种类:
    # 纯标签有哪几种:把每个文件地址按\\分割,第2个元素就是标签!
    # set()获得无序不重复元素集!
    pure_train_labels = set( [ p.split('\\')[1] for p in train_data_path ] )
    pure_test_labels = set( [ p.split('\\')[1] for p in test_data_path ] )
    
    # 查看一下:
    pure_train_labels, pure_test_labels
    ({'airplane', 'lake'}, {'airplane', 'lake'})
    
    5. 把4获得的标签种类,转为数字索引的形式(字典):
    pure_train_labels_to_index = dict( (index, name) for (name, index) in enumerate(pure_train_labels) )
    pure_test_labels_to_index = dict( (index, name) for (name, index) in enumerate(pure_test_labels) )
    
    # 查看一下:
    pure_train_labels_to_index, pure_test_labels_to_index
    ({'airplane': 0, 'lake': 1}, {'airplane': 0, 'lake': 1})
    
    6. 获得所有图片的标签:
    # 还是把每个文件地址按\\分割,第2个元素就是标签!不需要用set(因为就要获得所有图片的)
    train_labels = [ p.split('\\')[1] for p in train_data_path ]
    test_labels = [ p.split('\\')[1] for p in test_data_path ]
    
    # 查看一下:前3个
    train_labels[0:3], test_labels[0:3]
    (['airplane', 'airplane', 'lake'], ['airplane', 'lake', 'airplane'])
    
    7. 把6获得的标签全转为对应的索引值:利用5中的字典!
    # 获取“键”对应的“(索引)值”:
    train_labels = [ pure_train_labels_to_index.get(label) for label in train_labels ]
    test_labels = [ pure_test_labels_to_index.get(label) for label in test_labels ]
    
    # 查看一下:还是前3个 —— 可与上面对比看对不对!
    ([0, 0, 1], [0, 1, 0])
    
    8. 地址、标签获取完毕;最后列出所有后面会用到的变量名:
    # 文件地址:后面tf.io.read_file( path )需要输入它们
    train_data_path, test_data_path
    
    # 测试集、训练集所有图像对应的标签:tf.data()需要!
    train_labels, test_labels
    

    补充:

    相关文章

      网友评论

          本文标题:TF2.0:训练集、测试集的地址、标签获取完整流程!

          本文链接:https://www.haomeiwen.com/subject/ccttahtx.html