美文网首页Tensorflow
学习笔记TF023:下载、缓存、属性字典、惰性属性、覆盖数据流图

学习笔记TF023:下载、缓存、属性字典、惰性属性、覆盖数据流图

作者: 利炳根 | 来源:发表于2017-06-09 11:49 被阅读58次

    确保目录结构存在。每次创建文件,确保父目录已经存在。确保指定路径全部或部分目录已经存在。创建沿指定路径上不存在目录。

    下载函数,如果文件名未指定,从URL解析。下载文件,返回本地文件系统文件名。如果文件存在,不下载。如果文件未指定,从URL解析,返回filepath 。实际下载前,检查下载位置是否有目标名称文件。是,跳过下载。下载文件,返回路径。重复下载,把文件从文件系统删除。

    import os
    import shutil
    import errno
    from lxml import etree
    from urllib.request import urlopen
    
    
    def ensure_directory(directory):
        directory = os.path.expanduser(directory)
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise e
    
    def download(url, directory, filename=None):
        if not filename:
            _, filename = os.path.split(url)
        directory = os.path.expanduser(directory)
        ensure_directory(directory)
        filepath = os.path.join(directory, filename)
        if os.path.isfile(filepath):
            return filepath
        print('Download', filepath)
        with urlopen(url) as response, open(filepath, 'wb') as file_:
            shutil.copyfileobj(response, file_)
        return filepath
    

    磁盘缓存修饰器,较大规模数据集处理中间结果保存磁盘公共位置,缓存加载函数修饰器。Python pickle功能实现函数返回值序列化、反序列化。只适合能纳入主存数据集。@disk_cache修饰器,函数实参传给被修饰函数。函数参数确定参数组合是否有缓存。散列映射为文件名数字。如果是'method',跳过第一参数,缓存filepath,'directory/basename-hash.pickle'。方法method=False参数通知修饰器是否忽略第一个参数。

    import functools
    import os
    import pickle
    
    def disk_cache(basename, directory, method=False):
        directory = os.path.expanduser(directory)
        ensure_directory(directory)
    
        def wrapper(func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                key = (tuple(args), tuple(kwargs.items()))
                if method and key:
                    key = key[1:]
                filename = '{}-{}.pickle'.format(basename, hash(key))
                filepath = os.path.join(directory, filename)
                if os.path.isfile(filepath):
                    with open(filepath, 'rb') as handle:
                        return pickle.load(handle)
                result = func(*args, **kwargs)
                with open(filepath, 'wb') as handle:
                    pickle.dump(result, handle)
                return result
            return wrapped
    
        return wrapper
    @disk_cache('dataset', '/home/user/dataset/')
    def get_dataset(one_hot=True):
        dataset = Dataset('http://example.com/dataset.bz2')
        dataset = Tokenize(dataset)
        if one_hot:
            dataset = OneHotEncoding(dataset)
        return dataset
    

    属性字典。继承自内置dict类,可用属性语法访问悠已有元素。传入标准字典(键值对)。内置函数locals,返回作用域所有局部变量名值映射。

    class AttrDict(dict):
    
        def __getattr__(self, key):
            if key not in self:
                raise AttributeError
            return self[key]
    
        def __setattr__(self, key, value):
            if key not in self:
                raise AttributeError
            self[key] = value
    

    惰性属性修饰器。外部使用。访问model.optimze,数据流图创建新计算路径。调用model.prediction,创建新权值和偏置。定义只计算一次属性。结果保存到带有某些前缀的函数调用。惰性属性,TensorFlow模型结构化、分类。

    import functools
    
    def lazy_property(function):
        attribute = '_lazy_' + function.__name__
    
        @property
        @functools.wraps(function)
        def wrapper(self):
            if not hasattr(self, attribute):
                setattr(self, attribute, function(self))
            return getattr(self, attribute)
        return wrapper
    
    class Model:
    
        def __init__(self, data, target):
            self.data = data
            self.target = target
            self.prediction
            self.optimize
            self.error
    
        @lazy_property
        def prediction(self):
            data_size = int(self.data.get_shape()[1])
            target_size = int(self.target.get_shape()[1])
            weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
            bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
            incoming = tf.matmul(self.data, weight) + bias
            return tf.nn.softmax(incoming)
    
        @lazy_property
        def optimize(self):
            cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
            optimizer = tf.train.RMSPropOptimizer(0.03)
            return optimizer.minimize(cross_entropy)
    
        @lazy_property
        def error(self):
            mistakes = tf.not_equal(
                tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
            return tf.reduce_mean(tf.cast(mistakes, tf.float32))
    

    覆盖数据流图修饰器。未明确指定使用期他数据流图,TensorFlow使用默认。Jupyter Notebook,解释器状态在不同一单元执行期间保持。初始默认数据流图始终存在。执行再次定义数据流图运算单元,添加到已存在数据流图。根据菜单选项重新启动kernel,再次运行所有单元。
    创建定制数据流图,设置默认。所有运算添加到该数据流图,再次运行单元,创建新数据流图。旧数据流图自动清理。
    修饰器中创建数据流图,修饰主函数。主函数定义完整数据流图,定义占位符,调用函数创建模型。

    import functools
    import tensorflow as tf
    
    def overwrite_graph(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            with tf.Graph().as_default():
                return function(*args, **kwargs)
        return wrapper
    @overwrite_graph
    def main():
        data = tf.placeholder(...)
        target = tf.placeholder(...)
        model = Model()
    
    main()
    

    API文档,编写代码时参考:
    https://www.tensorflow.org/versions/master/api_docs/index.html
    Github库,跟踪TensorFlow最新功能特性,阅读拉拽请求(pull request)、问题(issues)、发行记录(release note):
    https://github.com/tensorflow/tensorflow
    分布式 TensorFlow:
    https://www.tensorflow.org/versions/master/how_tos/distributed/index.html
    构建新TensorFlow功能:
    https://www.tensorflow.org/master/how_tos/adding_an_op/index.html
    邮件列表:
    https://groups.google.com/a/tensorflow.org/d/forum/discuss
    StackOverflow:
    http://stackoverflow.com/questions/tagged/tensorflow
    代码:
    https://github.com/backstopmedia/tensorflowbook

    参考资料:
    《面向机器智能的TensorFlow实践》

    欢迎付费咨询(150元每小时),我的微信:qingxingfengzi

    相关文章

      网友评论

        本文标题:学习笔记TF023:下载、缓存、属性字典、惰性属性、覆盖数据流图

        本文链接:https://www.haomeiwen.com/subject/yrwxqxtx.html