风格迁移--生成你想要的风格
标签: pytorch
随着深度网络的流行,用AI作画也不再是问题,比如下面这一张:
![](https://img.haomeiwen.com/i2899252/d5542b2cf60b39e2.jpg)
你能看出来是手画的,还是自动生成的吗。
下面介绍一个风格迁移网络,能够帮你生成任意你想要的style。本文也会提供一个Starry_Night_Over_the_Rhone
的style模型,大家可以自己后台回复style_transform
获取代码和模型。
下面简单介绍一下风格迁移网络。
网络结构
![](https://img.haomeiwen.com/i2899252/f58d7184d16a7147.jpg)
上图就是快速风格迁移网络的结构,左边虚线框里面是一个Encoder-Decoder
结构,而右边整个就是一个训练好的vgg,主要用来做特征提取进而能够计算图片间的损失。
从图中可以看出输入是一个x
,经过Image Transform Net
会变为一个y^
,而这个y^
就是我们要的图片,也就是经过风格转换后的图片。比如我们输入一张东方明珠电视塔图片作为x
,那么文章刚开始的那个图片就是作为y^
,那这个y^
是如何学习得到的呢。主要靠后面vgg网络做损失,然后指导前面的Image Transform Net
学习。
下面介绍一下这个网络中最重要的损失函数,这个损失函数不同于之前的分类网络的损失,原来的分类网络一般就是一个交叉熵函数,但是这里的损失是一个预训练好的vgg,从图中可以看出,Loss Network
有三个输入,分别是ys y^ yc
,其中ys
就是风格图片,在本次实验中我们选择的是:
![](https://img.haomeiwen.com/i2899252/6dc60fd226df55b5.jpg)
正是由于ys
的缘故,所以我们的y^
在风格上和它非常像。而yc
其实就是x
。将这三个输入到vgg里面,然后计算利用vgg强大的特征提取能力,把提取的特征做为损失,我们的目的是使得我们y^
在内容上和yc
相近,而风格上和ys
更近,所以引出了两类损失,第一类是风格损失,第二类是内容损失,对应图中内容损失就直接对y^ yc
的中间特征用mse计算即可,也就是右边下面的那个损失,而风格损失是上面的三个,是对ys y^
的中间特征计算gram得到。
最后我们优化这两个损失就能保证我们的输出y^
在风格和ys
更近,而内容上和yc
更近。
代码简析
这部分对代码做一个简单的分析,其中main.py
是主函数,里面包含了两个主要方法train、stylize
,其中train是用来训练模型的, 如果你有充分的数据集,你可以自己加载数据来进行训练,只需要修改Config里面data_root
即可。
train里面主要就是加载数据,加载模型TransformerNet
,而TransformerNet
就是前面说的那个Image Transform Net
,损失网络同样使用的是Vgg
,在训练的过程中只更新TransformerNet
的参数,因为Vgg
是作为一个损失函数来用的,它直接使用一个ImageNet的预训练参数即可。
而stylize函数则提供了一个测试,当我们训练好了模型,就可以用这个函数来帮我们生成图片了,我们在Config里面指定一个content_path
,这里我们可以假定是一个东方明珠,你可以用其他代替。在stylize里面做的事情就是把TransformerNet
加载一下,注意要把训练好的模型给加载上去,然后一次前向传播即可。
如果你想直接用
如果你不太懂,而想尝下鲜,那么下载完代码后,直接输入python main.py
即可,如果想换个图片,直接修改第40行的content_path = 'images.jpeg' # 需要进行分割迁移的图片
即可。
欢迎大家关注我的微信公众号,未来上面会推送python
机器学习
算法学习
深度学习
论文阅读
以及偶尔的小鸡汤
等内容。ようこそいらっしゃい!
搜索 coderwangson 关注
![](https://img.haomeiwen.com/i2899252/6f3c61584a4b22c4.jpg)
网友评论