TensorFlow的变量在编程中被广泛使用,用于表示共享持久的状态。本章讲解了TensorFlow创建、更新、管理tf.Variable的方法。
变量的使用基于tf.Variable类。一个tf.Variable便是一个可以通过操作而被读取和修改的张量。tf.keras就是使用了tf.Variable来保存模型参数。
变量创建
可以通过指定一个初始值来创建一个变量,这样变量就有和初始值一样的数据类型。可能是因为tf.Variable基于tf.Tensor,tf.Variable用起来与Tensor相似,也有dtype,shape属性,也可以转换为NumPy。
my_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
my_variable = tf.Variable(my_tensor)
bool_variable = tf.Variable([False, False, False, True])
complex_variable = tf.Variable([5 + 4j, 6 + 1j])
print(my_variable)
print(bool_variable)
print(complex_variable)
print("my_variable Shape:", my_variable.shape)
print("my_variable DType:", my_variable.dtype)
print("my_variable As NumPy:", my_variable.numpy())
运行结果为:
<tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
array([[1., 2.],
[3., 4.]], dtype=float32)>
<tf.Variable 'Variable:0' shape=(4,) dtype=bool, numpy=array([False, False, False, True])>
<tf.Variable 'Variable:0' shape=(2,) dtype=complex128, numpy=array([5.+4.j, 6.+1.j])>
my_variable Shape: (2, 2)
my_variable DType: <dtype: 'float32'>
my_variable As NumPy: [[1. 2.]
[3. 4.]]
对张量的操作大多可用于变量,包括reshape操作(尽管变量不会被reshape)
正如前面提到的,变量的存储是用张量。你可以使用tf.Variable.assign对一个变量重新赋值。调用assgin并不会生成一个新的张量,而是将已经存在的张量的内存进行重用。但是在进行assgin时,是不能对变量的shape做出改变的,否则将报错。对变量的操作,实际上是对变量指向的张量内存的操作。通过变量来创建新的变量时,并不会共享原变量的内存,而是新建了一个相同的张量,然后让新变量指向新的张量。
生命周期,命名和监控
Python开发TensorFlow时,tf.Variable实例用友与Python对象一样的生命周期。当没有引用指向变量时,变量就会被自动析构回收。
变量可以命名,这可以帮助你在程序运行时追踪和调试时观察变量。变量的名称在模型保存和重载时将一直存在。默认情况下,模型的参数变量将自动获得一个唯一的命名,若非必要,则无需指定名称。
尽管变量对于微分来说是重要的,但是有些变量是不需要被微分的。你可以在变量创建时通过设置training开关来控制是否要对一个变量进行微分。比如,一个计步器是不需要被微分的。
step_counter = tf.Variable(1, trainable=False)
变量和张量的预先分配
TensorFlow为了提供更好的性能,TensorFlow将根据dtype将张量和变量分配至处理最快的设备,这意味着大多数的变量将会分配至GPU(若是存在的话)。但是,你仍然可以干涉这个动作。下面的代码片段将一个浮点型的张量和变量强行分配至CPU(即使GPU存在)。通过开启设备分配日志,你可以看到变量和张量被分配到的设备。
变量或张量存储在一个设备上,而计算却在另外的设备上的情况时有发生。这种情况下,设备之间的数据会进行拷贝而导致延迟。若是强行将计算运行在数据存储的设备上,则可以有效减少数据的拷贝。
with tf.device('CPU:0'):
a = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.Variable([[1.0, 2.0, 3.0]])
with tf.device('GPU:0'):
# Element-wise multiply
k = a * b
print(k)
网友评论