TensorFlow 是用数据流图(data flow graph)做计算的,它由节点(node)和边(edge)组件的有向无环图(directed acycline graph,DAG)。
节点表示计算单元,而边表示被计算单元消费或生产的数据。在 tf.Graph
的上下文中,每个 API 的调用定义了 tf.Operation
(节点),每个节点可以有零个或多个输入和输出的 tf.Tensor
(边)。
比如,定义 Python 变量 x
:
g = tf.Graph()
with g.as_default():
x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
print('name is', x.name)
print(x)
输出:
name is Const:0
Tensor("Const:0", shape=(2, 2), dtype=float32)
这里 x
定义了一个名叫 Const
的新节点(tf.Operation
)加入到从上下文集成下类的默认 tf.Graph
中。该节点返回一个名称为 Const:0
的 tf.Tensor
(边)。
由于 tf.Graph
中每个节点都是唯一的,如果依据在图中存在一个名称为 Const
的节点(这是所有 tf 常量的默认名称),TensorFlow 将在名称上添加后缀 _1
、_2
等使其名称唯一。当然,也可以自定义名称。
g = tf.Graph()
with g.as_default():
x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
x1 = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
s = tf.constant([[1, 2], [3, 4]], name='SG')
print(x.name, x1.name, s.name)
输出:
Const:0 Const_1:0 SG:0
输出的 tf.Tensor
和其相关的 tf.Operation
名称相同,但是加上了 :ID
形式的后缀。这个 ID 是一个递增的整数,表示该运算产生了多少个输出。但是可以存在有多个输出的运算,这种情况下,:0
,:1
等后缀会被加到由该运算产生的 tf.Tensor
名字后。
也可以通过调用 tf.name_scope
函数定义的一个上下文,为该上下文中所有的运算添加命名范围前缀。这个前缀是用 /
分割的一个名称列表:
g = tf.Graph()
with g.as_default():
with tf.name_scope('A'):
x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
y = x
print(x.name)
with tf.name_scope('B'):
x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
z = x + y
print(x.name, z.name)
输出
A/Const:0
A/B/Const:0 add:0
也可以这样:
g1 = tf.Graph()
g2 = tf.Graph()
with g1.as_default():
with tf.name_scope('A'):
x = tf.constant(1, name='x')
print(x)
with g2.as_default():
with tf.name_scope('B'):
x = tf.constant(1, name='x')
print(x)
输出:
Tensor("A/x:0", shape=(), dtype=int32)
Tensor("B/x:0", shape=(), dtype=int32)
图放置——tf.device
tf.device
创建一个和设备相符的上下文管理器。这个函数允许使用者请求将一个上下文创建的所有运算放置在相同的设备上。由 tf.device
指定的设备不仅仅是物理设备。它可以是远程服务器、远程设备、远程工作者即不同种类的物理设备(GPU、CPU、TPU)。它需要遵照一个设备的指定规范才能正确地告知框架来使用所需设备。一个设备指定规范有如下形式:
"/job:<JOB_NAME>/task:<TASK_INDEX>/device:<DEVICE_TYPE>:<DEVICE_INDEX>"
-
<JOB_NAME>
:是一个由字母和数字构成的字符串,首字母不能是数字; -
<DEVICE_TYPE>
:是一个已注册过的设备类型(CPU或GPU); -
<TASK_INDEX>
:是一个非负整数,代表了名为<JOB_NAME>
的工作中的任务编号;
with tf.device('/job:foo'):
# ops created here have devices with /job:foo
with tf.device('/job:bar/task:0/device:gpu:2'):
# ops created here have the fully specified device above
with tf.device('/device:gpu:1'):
# ops created here have the device '/job:foo/device:gpu:1'
边
TensorFlow 的边有两种连接关系:数据依赖与控制依赖。其中实线边表示数据依赖,代表数据,即张量。虚线边表示控制依赖(control dependency),可以用于控制操作的运行,这被用来确保 happens-before 关系,这类边上没有数据流过,但源节点必须在目的节点开始前完成执行。
g = tf.Graph()
g.control_dependencies(control_inputs)
节点
graph 中的节点又称为算子,它代表一个操作(tf.Operation
),一般用来表示施加的数学运算,数据输入的起点及输出的终点,或者是读取/写入持久变量(persistent variable)的终点。
网友评论