美文网首页程序员@IT·互联网
使用PAI深度学习读取OSS文件

使用PAI深度学习读取OSS文件

作者: 万千钧 | 来源:发表于2017-08-01 15:04 被阅读0次

    在PAI上, 使用TensorFlow读取OSS文件

    作者: 万千钧
    转载需注明出处

    本文适合有一定TensorFlow基础, 且准备使用PAI的同学阅读

    目录

    1. 如何使用PAI上读取OSS数据
    2. 如何使用PAI上写入数据到OSS
    3. 如何减少读取的费用开支
    4. 使用OSS需要注意的问题

    1. 在PAI上读取数据

    Python不支持读取oss的数据, 故所有调用python Open() os.path.exist() 等文件, 文件夹操作的
    函数的代码都无法执行.

    Scipy.misc.imread(),numpy.load()

    那如何在PAI读取数据呢, 通常我们采用两种办法.

    1. 如果只是简单的读取一张图片, 或者一个文本等, 可以使用tf.gfile下的函数, 具体成员函数如下
    tf.gfile.Copy(oldpath, newpath, overwrite=False) # 拷贝文件
    tf.gfile.DeleteRecursively(dirname) # 递归删除目录下所有文件
    tf.gfile.Exists(filename) # 文件是否存在
    tf.gfile.FastGFile(name, mode='r') # 无阻塞读取文件
    tf.gfile.GFile(name, mode='r') # 读取文件
    tf.gfile.Glob(filename) # 列出文件夹下所有文件, 支持pattern
    tf.gfile.IsDirectory(dirname) # 返回dirname是否为一个目录
    tf.gfile.ListDirectory(dirname) # 列出dirname下所有文件
    tf.gfile.MakeDirs(dirname) # 在dirname下创建一个文件夹, 如果父目录不存在, 会自动创建父目录. 如果
    文件夹已经存在, 且文件夹可写, 会返回成功
    tf.gfile.MkDir(dirname) # 在dirname处创建一个文件夹
    tf.gfile.Remove(filename) # 删除filename
    tf.gfile.Rename(oldname, newname, overwrite=False) # 重命名
    tf.gfile.Stat(dirname) # 返回目录的统计数据
    tf.gfile.Walk(top, inOrder=True) # 返回目录的文件树
    

    具体的文档可以参照这里(可能需要翻墙)

    1. 如果是一批一批的读取文件, 一般会采用tf.WhoFileReader()tf.train.batch() /
      tf.train.shuffer_batch()

    接下来会重点介绍常用的 tf.gfile.Glob, tf.gfile.FastGFile, tf.WhoFileReader()
    tf.train.shuffer_batch()

    读取文件一般有两步

    1. 获取文件列表
    2. 读取文件

    如果是批量读取, 还有第三步

    1. 创建batch

    从代码上手:
    在使用PAI的时候, 通常需要在右侧设置读取目录, 代码文件等参数, 这些参数都会通过--XXX的形式传入

    tf.flags可以提供了这个功能

    import tensorflow as tf
    
    FLAGS = tf.flags.FLAGS
    # 前面的buckets, checkpointDir都是固定的, 不建议更改
    
    tf.flags.DEFINE_string('buckets', 'oss://XXX', '训练图片所在文件夹')
    tf.flags.DEFINE_string('batch_size', '15', 'batch大小')
    
    # 获取文件列表
    
    files = tf.gfile.Glob(os.path.join(FLAGS.buckets,'*.jpg')) # 如我想列出buckets下所有jpg文件路径
    

    接下来就分两种情况了

    1. (小规模读取时建议) tf.gfile.FastGfile()
    for path in files:
        file_content = tf.gfile.FastGFile(path, 'rb').read() # 一定记得使用rb读取, 不然很多情况下都会报错
        image = tf.image.decode_jpeg(file_content, channels=3) # 本教程以JPG图片为例
    
    1. (大批量读取时建议) tf.WhoFileReader()
    reader = tf.WholeFileReader()  # 实例化一个reader
    fileQueue = tf.train.string_input_producer(files)  # 创建一个供reader读取的队列
    file_name, file_content = reader.read(fileQueue)  # 使reader从队列中读取一个文件
    image_content = tf.image.decode_jpeg(file_content, channels=3)  # 讲读取结果解码为图片
    label = XXX  # 这里省略处理label的过程
    batch = tf.train.shuffle_batch([label, image_content], batch_size=FLAGS.batch_size, num_threads=4,
                                   capacity=1000 + 3 * FLAGS.batch_size, min_after_dequeue=1000)
    
    sess = tf.Session()  # 创建Session
    tf.train.start_queue_runners(sess=sess)  # 重要!!! 这个函数是启动队列, 不加这句线程会一直阻塞
    labels, images = sess.run(batch)  # 获取结果
    

    现在解释下其中重要的部分

    1. tf.train.string_input_producer, 这个是把files转换成一个队列, 并且需要 tf.train.start_queue_runners 来启动队列
    2. tf.train.shuffle_batch 参数解释
    • batch_size 批大小, 每次运行这个batch, 返回多少个数据
    • num_threads 运行线程数, 在PAI上4个就好
    • capacity 随机取文件范围, 比如你的数据集有10000个数据, 你想从5000个数据中随机取, capacity就设置成5000.
    • min_after_dequeue 维持队列的最小长度, 这里只要注意不要大于capacity即可

    2. 写入数据

    1.直接使用tf.gfile.FastGFile()写入

    tf.gfile.FastGFile(FLAGS.checkpointDir + 'example.txt', 'wb').write('hello world')
    
    1. 通过tf.gfile.Copy()拷贝
    tf.gfile.Copy('./example.txt', FLAGS.checkpointDir + 'example.txt')
    

    通过这两种方法, 文件都会出现在 '输出目录/model/example.txt' 下

    3. 费用开支

    这里只讨论读取文件所需要的费用开支

    原则上来说, PAI不跨区域读取OSS是不收费的, 但是OSS的API是收费的. PAI在使用 tf.gile.Glob 的时候
    会产生GET请求, 在写入tensorboard的时候, 也会产生PUT请求. 这两种请求都是按次收费的, 具体价格如下

    标准型单价: 0.01元/万次

    低频访问型单价: 0.1元/万次

    归档型单价: 0.1元/万次

    当数据集有几十万图片, 通过tf.gile.Glob一次就需要几毛钱. 所以减少费用开支的方法就是减少GET请求次数

    这里给出几种解决思路

    1. 最好的解决思路, 把所有会使用到的数据, 一并上传传到OSS, 然后使用tensorflow拷贝到运行时目录, 最后通过tensorflow读取, 这样是最节省开支的.

    2. 通过tfrecords, 在本地, 提前把几十上百张图片通过tfrecords存下来, 这样读取的时候可以减少GET请求

    1. 把训练使用的图片随着代码的压缩包一起传上去, 不走OSS读取

    三种方法都可以显著的减少开支.

    4.使用中需要注意的

    事实上, 每次读取传过来的地址就是 oss://你的buckets名字/XXX, 本以为不需要在PAI界面上 设置, 直接读取这个目录就好, 事实上并不如此.

    PAI没有权限读取不在数据源目录和输出目录下的文件, 所以在使用路径前, 确保他们已经在控制台右侧设置过.

    右侧控制台截图右侧控制台截图

    OSS路径推荐使用
    FLAGS.checkpointDir
    FLAGS.summaryDIr
    这样的形式传入, 经过测试好像也只有这两个目录下有写权限
    FLAGS.buckets下所有文件夹都有读写权限

    相关文章

      网友评论

        本文标题:使用PAI深度学习读取OSS文件

        本文链接:https://www.haomeiwen.com/subject/iqmjlxtx.html