美文网首页tensorflow相关tensorflow
关于Ryan Dahl的tensorflow-resnet中Co

关于Ryan Dahl的tensorflow-resnet中Co

作者: Traphix | 来源:发表于2016-07-01 15:38 被阅读760次

    本文旨在学习Ryan Dahl的tensorflow-resnet源码中的Config类的基本作用。因为,它真的真的很有趣

    tensorflow-resnet的repository中有个文件叫config.py,Config类就是在这个文件中被定义的。它能够很方便的实现基于tensorflow编写程序的不同参数在不同scope中的隔离管理(很拗口,一会儿上例子),Config类有以下几个特点:

    • 它可以被认为是包含了多个dict的list
    • 它的内部参数在不同variable scope中是“隔离”的

    说了那么多,相信谁都没看明白,那么举个栗子:

    c = Config()
    c['p1'] = 1
    c['p2'] = 1
    c['p3'] = 1
    # c['p1'] = 1, c['p2'] = 1, c['p3'] = 1, c['p4']不存在
    
    with tf.variable_scope('foo'):
        c['p1'] = 2
        c['p4'] = 2
        # c['p1'] = 2, c['p2'] = 1, c['p3'] = 1, c['p4'] = 2
    
        with tf.variable_scope('bar'):
            c['p2'] = 2
            # c['p1'] = 2, c['p2'] = 2, c['p3'] = 1, c['p4'] = 2
    
    with tf.variable_scope('baz'):
        c['p3'] = 2
        # c['p1'] = 1, c['p2'] = 1, c['p3'] = 2, c['p4']不存在
    
    # c['p1'] = 1, c['p2'] = 1, c['p3'] = 1, c['p4']不存在
    

    程序内各项参数在不同位置的取值我已经注释出来了,很明显,不同variable scope中的参数是隔离的,你在’foo‘中设置的参数在’baz‘不起作用,在’foo‘中新定义的参数在其他scope中看不到(但在’foo‘中的’bar‘内可以看到)。

    关于Config类的特点还有待挖掘,以上只是说明了它最基本的特点,下面给出Ryan Dahl的config.py的源码,你也可以去他的repository中看,这是链接

    # This is a variable scope aware configuation object for TensorFlow
    
    import tensorflow as tf
    
    FLAGS = tf.app.flags.FLAGS
    
    class Config:
        def __init__(self):
            root = self.Scope('')
            for k, v in FLAGS.__dict__['__flags'].iteritems():
                root[k] = v
            self.stack = [ root ]
    
        def iteritems(self):
            return self.to_dict().iteritems()
    
        def to_dict(self):
            self._pop_stale()
            out = {}
            # Work backwards from the flags to top fo the stack
            # overwriting keys that were found earlier.
            for i in range(len(self.stack)):
                cs = self.stack[-i]
                for name in cs:
                    out[name] = cs[name]
            return out
    
        def _pop_stale(self):
            var_scope_name = tf.get_variable_scope().name
            top = self.stack[0]
            while not top.contains(var_scope_name):
                # We aren't in this scope anymore
                self.stack.pop(0)
                top = self.stack[0]
    
        def __getitem__(self, name):
            self._pop_stale()
            # Recursively extract value
            for i in range(len(self.stack)):
                cs = self.stack[i]
                if name in cs:
                    return cs[name]
    
            raise KeyError(name)
    
        def set_default(self, name, value):
            if not (name in self):
                self[name] = value
    
        def __contains__(self, name):
            self._pop_stale()
            for i in range(len(self.stack)):
                cs = self.stack[i]
                if name in cs:
                    return True
            return False
    
        def __setitem__(self, name, value):
            self._pop_stale()
            top = self.stack[0]
            var_scope_name = tf.get_variable_scope().name
            assert top.contains(var_scope_name)
    
            if top.name != var_scope_name:
                top = self.Scope(var_scope_name)
                self.stack.insert(0, top)
    
            top[name] = value
    
        class Scope(dict):
            def __init__(self, name):
                self.name = name
    
            def contains(self, var_scope_name):
                return var_scope_name.startswith(self.name)
    
    
    
    # Test
    if __name__ == '__main__':
    
        def assert_raises(exception, fn):
            try:
                fn()
            except exception:
                pass
            else:
                assert False, "Expected exception"
    
        c = Config()
    
        c['hello'] = 1
        assert c['hello'] == 1
    
        with tf.variable_scope('foo'):
            c.set_default("bar", 10)
            c['bar'] = 2
            assert c['bar'] == 2
            assert c['hello'] == 1
    
            c.set_default("mario", True)
    
            with tf.variable_scope('meow'):
                c['dog'] = 3
                assert c['dog'] == 3
                assert c['bar'] == 2
                assert c['hello'] == 1
    
                assert c['mario'] == True
    
            assert_raises(KeyError, lambda: c['dog'])
            assert c['bar'] == 2
            assert c['hello'] == 1
    
    

    相关文章

      网友评论

        本文标题:关于Ryan Dahl的tensorflow-resnet中Co

        本文链接:https://www.haomeiwen.com/subject/bvnfjttx.html