美文网首页我爱编程
tensorflow学习笔记-会话机制(session)

tensorflow学习笔记-会话机制(session)

作者: 听风1996 | 来源:发表于2018-05-08 15:52 被阅读259次

    在TensorFlow中,有两种用于运行计算图(graph)的会话(session)

    • tf.Session( )

    • tf.InteractivesSession( )

    1. tf.Session( )

    要使用tf,我们必须先构建(定义)graph,之后才能运行graph。

    1.1 非交互式会话中的例子

    import tensorflow as tf
    
    # 构建graph
    a = tf.add(3, 5) 
    
    # 运行graph
    sess = tf.Session()  # 创建tf.Session的一个对象sess
    print(sess.run(a)) 
    
    sess.close()         # 关闭sess对象
    

    一个session可能会占用一些资源,比如变量、队列和读取器(reader)。我们使用sess.close()关闭会话或者使用上下文管理器释放这些不再使用的资源。

    1.2 建议的tf.Session( )写法

    import tensorflow as tf  
    
    # 构建graph
    matrix1 = tf.constant([[3., 3.]])  
    matrix2 = tf.constant([[2.], [2.]])  
    
    product = tf.matmul(matrix1, matrix2)  
    
    # 运行graph
    with tf.Session() as sess:          # 使用"with"语句,自动关闭会话
        print(sess.run(product))  
    

    1.3 Fetch(取回)

    在使用sess.run( )运行图时,我们可以传入fetches,用于取回某些操作或tensor的输出内容。fetches可以是list,tuple,namedtuple,dict中的任意一个。fetches可以是一个列表,在op的一次运行中一起获得(而不是逐个去获取 tensor)多个tensor值。

    import tensoflow as tf
    from collections import namedtuple
    
    a = tf.constant([10, 20])
    b = tf.constant([1.0, 2.0])
    MyData = namedtuple('MyData', ['a', 'b'])
    
    with tf.Session() as sess:
        c = sess.run(a)            # fetches可以为单个数a
        d = sess.run([a, b])       # fetches可以为一个列表[a, b]
        v = sess.run({'k1': MyData(a, b), 'k2': [b, a]}) 
    
        print(c)
        print(d)
        print(d[0])
        print(v) 
    '''
    v is a dict and v['k1'] is a MyData namedtuple with the numpy array [10, 20] and the numpy array [1.0, 2.0]. v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array [10, 20].
    '''
    

    1.4 Feed(注入)

    TensorFlow提供了feed注入机制, 它可以临时替代graph中任意op操作的输入tensor,可以对graph中任何操作提交补丁(直接插入一个tensor)。

    feed机制只在调用它的方法内有效,方法结束,feed就会消失。最常见的用例是把某些特殊操作为feed注入的对象。你可以提供数据feed_dict,作为sess.run( )调用的参数。使用tf.placeholder( ),为某些操作的输入创建占位符。

    import tensorflow as tf
    import numpy as np
    
    x = np.ones((2, 3))
    y = np.ones((3, 2)) 
    
    input1 = tf.placeholder(tf.int32)
    input2 = tf.placeholder(tf.int32)
    
    output = tf.matmul(input1, input2)
    
    with tf.Session() as sess:
        print(sess.run(output, feed_dict = {input1:x, input2:y}))
    

    如果没有正确提供tf.placeholder( ),feed操作将产生错误。注意,feed注入的值不能是tf的tensor对象,应该是Python常量、字符串、列表、numpy ndarrays,或者TensorHandles。

    1.5 分布式训练

    从version 0.8之后,TensorFlow开始支持分布式计算的机器学习,而且TensorFlow会充分利用CPU、GPU等计算资源。如果检测到GPU,TensorFlow会优先使用GPU运行程序。用字符串标识设备,目前支持的设备包括:

    “/cpu:0”:机器的第一个CPU。

    “/gpu:0”:机器的第一个GPU, 如果有的话

    “/gpu:1”:机器的第二个GPU, 以此类推

    当计算机有多个GPU时,通过tf.device( ),我们可以指定用哪个GPU来执行。代码示例如下:

    # 在with tf.device()下,构建graph
    with tf.device("/gpu:0"):
        a = tf.constant([[3., 3.]])
        b = tf.constant([[2.], [2.]])
        product = tf.matmul(a, b)
    
    # 运行graph
    with tf.Session() as sess:    
        print(sess.run(product))
    

    2. tf.InteractivesSession( )

    当python编辑环境是shell、IPython等交互式环境时,我们使用类tf.InteractiveSession代替类tf.Session,用方法tensor.eval( ),operation.run( ) 代替sess.run( ),这样可避免用一个变量sess来持有会话。其中更多地使用 tensor.eval(),所有的表达式都可以看作是tensor。

    // 进入python3交互式环境
    # python3
    
    >>> import tensorflow as tf  
    
    // 进入一个交互式会话
    >>> sess = tf.InteractiveSession()
    
    >>> a = tf.constant(5.0)
    >>> b = tf.constant(6.0)
    >>> c = a * b
    
    // We can just use 'c.eval()' without passing 'sess'
    >>> print(c.eval()) 
    
    >>> sess.close()   // 关闭交互式会话
    
    >>> exit()        // 退出python3交互式环境
    

    相关文章

      网友评论

        本文标题:tensorflow学习笔记-会话机制(session)

        本文链接:https://www.haomeiwen.com/subject/hhcwrftx.html