美文网首页
Tensorflow入门

Tensorflow入门

作者: 策马踏清风 | 来源:发表于2020-07-09 14:47 被阅读0次

    基本概念

    一、计算模型——计算图

    1.1基本概念

    1. 计算图Tensorflow最基本的概念,Tensorflow中所有的计算都会转为计算图上的节点。
    2. 张量Tensor,可以理解为多维数组(0-d tensor:标量,1-d tensor:向量,2-d tensor:矩阵)表明了数据结构。
    3. flow,体现了计算模型(张量之间通过计算相互转换的过程)。Tensorflow每个计算都是计算图上的节点,节点之间的边描述了计算之间依赖关系。

    1.2 基本过程

    1. 定义计算图中所有计算
    2. 执行计算

    1.3使用默认计算图

    # 引入TensorFlow
    import tensorflow as tf
    
    # 定义两个张量
    a = tf.constant([1.0, 2.0], name="a")
    b = tf.constant([2.0, 3.0], name="b")
    
    # Tensorflow会自动将定义的计算转为计算图上的节点
    result = a + b
    
    # 系统会自动维护一个默认的计算图
    # 可以通过tf.get_default_graph获取当前默认计算图
    #通过a.graph可以查看张量所属的计算图。因为我们没指定,所以是默认计算图
    print(a.graph is tf.get_default_graph)
    

    1.4生成新的计算图

    注:不同的计算图张量和运算都不会共享

    import tensorflow as tf
    
    # 生成新的计算图
    g1 = tf.Graph()
    # 设为默认
    whith g1.as_default():
        # 在计算图上定义变量'v',并且初始化为0
        v = tf.get_variable('v', initializer=tf.zeros_initalizer()(shape=[1]))
    
    g2 = tf.Graph()
    whith g2.as_default():
        # 另一个计算图上定义变量'v',初始化为1
        v = tf.get_variable('v', initializer=tf.ones_initializer()(shape=[1]))
    
    # 读取g1上的变量
    whith tf.Session(graph=g1) as sess:
        # 初始化图上的变量
        tf.global_variables_initializer().run()
        # 退出变量作用域(返回上层)并开启变量复用
        whith tf.variable_scope("", reuse=True):
            # 输出0
            print(sess.run(tf.get_variable("v")))
    
    # 在计算图 g2 中读取变量'v'的取值
    with tf.Session(graph=g2) as sess:
        tf.global_variables_initializer().run()
        with tf.variable_scope("", reuse=True):
            #输出为 1
            print(sess.run(tf.get_variable("v")))
    

    1.5指定运行计算的设备

    g = tf.Graph()
    with g.device('/gpu:0'):
        result = a + b
    

    1.6常用集合

    Tensorflow的图可以有效管理资源(张量、变量等)。其中自动维护的集合就是访问这些资源的有效手段

    集合名 集合内容 使用场景
    tf.GraphKeys.VARIABLES 所有变量 持久化TensorFlow模型
    tf.GraphKeys.TRAINABLE_VARIABLES 可学习的变量(一般指神经网络的参数) 模型迅雷、生成可视化内容
    tf.GraphKeys.SUMMARIES 日志生成相关的张量 TensorFlow计算可视化
    tf.GraphKeys.QUEUE_RUNNERS 处理输入的QuecueRunner 输入处理
    tf.GraphKeys.MOVING_AVERAGE_VARIABLES 所有计算了滑动平均值的变量 计算变量滑动平均值

    二、数据类型——张量

    2.1概念

    1. 张量可以简单理解为多维数组(矩阵),零阶表示标量(scalar),一节代表向量(vector),n阶代表矩阵。
    2. 张量并没有保存具体数字,保存的是的到这些数字的计算过程

    2.2创建张量

    import tensorflow as tf
    # tf.constant是一个计算,这个计算的结果为一个张量,保存在a中
    a = tf.constant([1.0, 2.0], name='a')
    b = tf.constant([2.0, 3.0], name='b')
    result - tf.add(a, b, name='add')
    print(result)
    # Tensor("add_2:0", shape=(2,), dtype=float32)
    
    1. 计算结果也是一个张量
    2. 张量的结构中有三个要素,名字(name是唯一标准)、维度(shape)、类型(type)
    3. 张量可以通过'node:src_output'的形式命名,其中node为节点名称,src_output表示当前张量来自节点的第几个输出。shape=(2,)是张量的维度信息,这个说明result是一个一维数组,且数组长度是2。第三个是类型,每个张量都有唯一的类型,Tensorflow的计算类型必须相同。

    类型不匹配的例子

    import tensorflow as tf
    a = tf.constant([1, 2], name='a')
    b = tf.constant([2.0, 3.0], name='b')
    result = a + b
    #类型不匹配报错:
    #ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: 
    #'Tensor("b_1:0", shape=(2,), dtype=float32)'
    

    2.3tf.constant

    用于计算得出张量的方法,原型如下

    tf.constant(
        value,
        dtype=None,
        shape=None,
        name='Const',
        verify_shape=False
    )
    

    value是必填的

    value的数量必须小于shape代表的矩阵最大承受数量,少于则会用最后一个值填充

    tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
    

    2.4tf.Variable

    tf.Variable的运算结果也是一个张量

    # 使用正态分布的方式创建一个2*3的矩阵,随机元素标准差为2
    tf.Variable(tf.random_normal([2,3], stddev=2))
    

    2.5张量的作用

    1. 对中间计算结果的引用,提高代码可读性
    2. 计算图构造完成后,张量用来计算结果

    三、运行模型——会话

    Tensorflow中的会话(session)是来执行定义好的运算

    with tf.Session() as sess:
        #使用创建好的会话来计算关心的结果(默认图)
        sess.run()
    

    tf.Tensor.eval函数可以计算一个张量的取值

    sess = tf.Session()
    with sess.as_default():
        print(result.eval)
    
    #下面代码和上面功能相同
    sess = tf.Session()
    #下面两个命令相同
    print(sess.run(result))
    print(result.evsl(session=sess))
    sess.close()   #书本上没有这句
    

    tf.InteractiveSession函数可以省去将产生的会话注册为默认会花的过程。

    #交互式环境下直接构建默认会话
    sess = tf.InteractiveSession()
    print(result.eval())
    sess.close()
    

    相关文章

      网友评论

          本文标题:Tensorflow入门

          本文链接:https://www.haomeiwen.com/subject/qrkjcktx.html