TensorFlow是谷歌2015年开源的用于深度学习算法开发的一套目前非常流行的框架,不仅在学术界而且在工业界都非常收欢迎。用TensorFlow实现一个神经网络十分的简洁,但是往往有时候你自己实现出来的神经网络的表现效果和预期有差别,或者说模型不工作。那么如何进行Debug,我自己最近有了一些小的体会,分享给大家。
本文主要分为两个部分,一部分是使用sess.run()方法获得模型运行期间的各个tensor的值,另一部分是使用TensorFlow自带的Debug工具tfdbg。
1.获取模型中tensor的值
在tensorflow构建的模型中,数据是以tensor的形式的存储和计算的。tensor我们主要关心的是它的shape和具体的value这两个属性,通过这两个属性我们可以判断我们模型是不是按照我们预先设想的方式进行运作的。通过sess.run()中fetch你指定的tensor,可以拿到我们想要观察的tensor。如下:
线性模型拟合demo上面是一个简单的拟合一个线性模型的demo,我们通过fetch+sess.run()的方式获得了每次训练完成后的b和w,并每隔50次迭代将其type,shape和value打印出来,下面是运行结果:
step: 0 type of b:shape of b: (1,) value of b: [ 0.67431259]type of w:shape of w: (1, 2) value of w: [[-0.33759621 0.51075733]]
step: 50 type of b:shape of b: (1,) value of b: [ 0.30864805]type of w:shape of w: (1, 2) value of w: [[ 0.08913503 0.19295232]]
step: 100 type of b:shape of b: (1,) value of b: [ 0.30055302]type of w:shape of w: (1, 2) value of w: [[ 0.0994734 0.19937272]]
step: 150 type of b:shape of b: (1,) value of b: [ 0.30003637]type of w:shape of w: (1, 2) value of w: [[ 0.0999667 0.19995736]]
step: 200 type of b:shape of b: (1,) value of b: [ 0.3000024]type of w:shape of w: (1, 2) value of w: [[ 0.09999783 0.1999972 ]]
通过运行结果我们可以看出,每次通过sess.run返回的tensor都是一个numpy数组,这样很方便我们观察每个tensor的值,shape等,然后确定模型是不是按照我们想要的方式进行运作的。除了通过sess.run(),还可以通过tf.Print和tf.Assert 的方式来观察我们想要的tensor,但是我个人还是比较喜欢用sess.run()的方式。
2.TensorFlow Debugger (tfdbg)
TensorFlow Debugger是TensorFlow自带的CLI调试工具tfdbg,这是一个很方便快捷的调试工具,官方的文档中给出了一个基于mnist数据集的一个例子,解释的很详细。下面我们还是在上面的拟合线性模型的demo上进行演示,代码如下:
tfdbg的一个demo我们可以看到tfdbg其实很简单只需要将session包装成一个用于调试的LocalCLIDebugWrapperSession类即可,这样在session run的时候就会自动的启动tfdbg的命令行界面,进行包装的代码很简单只有一行:
将sess包装成一个用于调试的LocalCLIDebugWrapperSession类上面代码在tfdbg运行时的命令行界面:
初始界面 sess.run()界面通过tfdbg我们可以很方便直观的看到每一次sess.run时候的数据图中各个节点的情况,包括我们想要观察的b和w。tfdbg的功能和用法还有很多,根据需要可以自己去看官方的文档TensorFlow Debugger (tfdbg) Command-Line-Interface。
3.总结
TensorFlow是一个很高级的框架,我们用它来实现深度学习算法很简洁,而且不需要考虑很底层的实现,这样有利有弊。好处是提高了开发的效率,坏处是不清楚底层的实现,有时候出现问题不能很好的找出bug所在。所以通过对模型算的debug的过程,也是进一步加深了自己对TensorFlow的理解。
个人觉得学习TensorFlow最好的方式就是自己找一个开源深度学习任务,然后自己将其复现出来,复现一定要做到指标和原有模型对齐,再这个过程中肯定会遇到各种各样的bug,但是程序员不就为了bug而生么,慢慢debug的过程,也就对TensorFlow有了更深的理解啦。
网友评论