TensorFlow笔记
NumPy数组
TensorFlow和Numpy的dtype属性是完全一致的,TensorFlow可以从NumPy中完美导入字符串数组,但不能在NumPy中显式指定dtype的属性。
将张量的某一维的值指定为None,可使该维具有可变长度。而将形状指定为None,可使该张量拥有任意维数,且每一维具有任意长度。
Graph对象
应创建新的数据流图将默认的数据流图忽略,或者获取默认数据流图的句柄并使用,而不是将默认数据流图和创建的数据流图混合使用。
Session对象
tf.Session()接受三个可选参数:target,graph,config。
Target是执行引擎,用于分布式设置。
graph指定加载的Graph对象,默认值None,表示当前默认的数据流图。当有多个数据流图时,应显式传入Graph对象,而不是在with语句块内创建Session对象。
config是所需选项。
Session.run()方法接受一个参数fetches,以及三个可选参数feed_dict、options、run_metadata。fetches可为Tensor对象(返回NumPy数组)或神经元(返回None),也可以是变量的列表(返回对应值的列表)
feed_dict一般用于占位符的添加输入或规模较大数据流图的部分测试(不用重新计算一遍)
Session使用方法:
sess = tf.Session()
...
sess.close()
or
with tf.Session() as sess:
...
placeholder占位符
feed_dict为字典对象
调用tf.placeholder()时,dtype是必须指定的,shape参数和name标识符是可选的。
Variable对象
Saver对象
保存模型
#模型定义...
#创建Saver对象
saver = tf.train.Saver
with tf.Session() as sess:
#训练...
saver.saver(sess, 'my-model', global_step=training_steps)
sess.close()
恢复模型
with tf.Session() as sess:
#验证之前是否已经保存了检查点文件
ckpt = tf.train.get_checkpoint_state(os.path.dirname(__file__))
if ckpt and ckpt.model_checkpoint_path:
#从检查点恢复模型
saver.restore(sess, ckpt.model_checkpoint_path)
initial_step = int(ckpt.model_checkpoint_path.replit('-', 1)[1])
CNN
简单CNN架构包含卷积层(tf.nn.conv2d)、非线性变换层(tf.nn.relu)、池化层(tf.nn.max_pool)和全连接层(tf.nn.matmul)。
网友评论