Spatial Transformer Networks实战

作者: 素娜93 | 来源:发表于2017-08-15 10:12 被阅读288次

    这几天在看Spatial Transformer Networks 空间变换网络,该网络的结构图如下所示,STN由三部分组成:分别为localization network,grid generator 和 bilinear sampler组成。

    STN

    其中U和V分别为输入/输出图片(也可以是feature map)

    Localization network: 局部网络,一般由全连接或者卷积神经网络加上回归层构成,输入为U, 输出为空间变换矩阵θ,如果是仿射变换,就是6个神经元,如果是射影变换就是8个神经元,可以参考这两篇博客:Creating a Gallery of Transformed Images以及图像的等距变换,相似变换,仿射变换,射影变换及其matlab实现

    grid generator:通过下面的变换,利用空间变换矩阵θ 在输入特征图上产生输出特征图像素应该被采样的坐标点。

    bilinear sampler:在输入特征图上,结合上一步产生的坐标,在对应位置进行双线性插值,得到输出特征图对应点的像素值!


    下面是STN的TensorFlow版本实现,这里简单记录一下编译的过程:

    1、下载 mnist_cluttered_60x60_6distortions.npz 数据集,放在data目录下;

    2、在utils文件夹下创建 __init__.py文件(空文件)即可,这样就将utils文件夹变成了Python模块,否则会提醒找不到 data_utils一系列模块。

    找不到 data_utils一系列模块

    3、修改main.py中root_dir、logs_dir、save_dir 和 vis_path的路径;

    4、在终端下进入工程:命令行输入python main.py 回车,程序开始运行。这时打开utils文件夹,会看到生成了4个对应的 .pyc文件。

    运行截图

    5、如果是在GPU上运行代码,可能还会出现显示的问题如下:同样在main.py文件import matplotlib.pyplot as plt 后边添加 plt.switch_backend(‘agg’)即可。

    显示出错

    6、用tensorboard进行可视化,将日志的地址指向程序日志输出的地址(/spatial_transformer_network/logs),进入工程目录如下输入命令:

    tensorboard可视化

    可以看到服务器端口6006已经在使用中,所以使用其它端口,复制 http://0.0.0.0:7001 到浏览器,就可以看到可视化图表了:

    Tensorboard可视化结果

    点击进入GRAPHS栏,可以看到上面程序TensorFlow计算图的可视化结果。

    Tensorboard计算图可视化结果

    7、程序运行结束后,可以看到生成了对应的文件:

    8、最后对第一下CPU和GPU的运行速度:可以看到两个速度差了6倍左右!

    左边为CPU,右边为GPU

    PS: 不知道是不是由于编辑器的问题,我的代码在终端可以训练,但是在本地不能单步调试,进入不了主函数。因为在调试过程中发现__name__ = {str}'main',和平常的类型不一致(__name__ = {str}'__main__'),最后将 if__name__ =='__main__': 修改为:if __name__ == 'main': 就可以了。

    简单记录一下,防止日后忘了,现在我要去分析结果啦!

    相关文章

      网友评论

      • 十一魔龙:不知道您最后的结果怎么样,为什么我从epoch_1.png往后的9宫格图片全是黑色的,并没有得到对应的transformation结果图片。

      本文标题: Spatial Transformer Networks实战

      本文链接:https://www.haomeiwen.com/subject/jflmrxtx.html