Style Transfer

Author: LuDon | Published 2019-07-08 17:46

    Introduction

    Implement fast neural style transfer with TensorFlow, following the paper [1].

    Principle

    Given a content image and a style image, the input image is optimized so that both the content loss and the style loss are as small as possible. The fast style transfer network architecture is shown below:


    (Figure: fast style transfer architecture — a trainable transform network followed by a fixed loss network.)

    The style image is fixed while the content image varies: the trained network turns an arbitrary input image into a stylized version in the specified style.

    • Transform network: its parameters are trained; it converts the content image into the transferred image.
    • Loss network: computes the style loss between the transferred image and the style image, and the content loss between the transferred image and the original content image (a sketch of both losses follows this list).
      After training, the transferred image produced by the transform network is similar in content to the input content image and similar in style to the specified style image.
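
    These two losses come from [1]: the content loss is the mean squared error between VGG feature maps of the transferred and content images, and the style loss is the mean squared error between their Gram matrices. A minimal sketch of both, assuming feature maps of shape [batch, height, width, channels]; the helper names gram_matrix, content_loss and style_loss are illustrative and do not appear in the original post:

    import tensorflow as tf

    def gram_matrix(feats):
        # Channel-by-channel correlations, averaged over spatial positions.
        b, h, w, c = [i.value for i in feats.get_shape()]
        feats = tf.reshape(feats, [b, h * w, c])
        return tf.matmul(feats, feats, transpose_a=True) / float(h * w * c)

    def content_loss(transfer_feats, content_feats):
        # Mean squared error between the feature maps of one chosen layer.
        return tf.reduce_mean(tf.square(transfer_feats - content_feats))

    def style_loss(transfer_feats_list, style_feats_list):
        # Sum of mean squared errors between Gram matrices over the style layers.
        return sum(tf.reduce_mean(tf.square(gram_matrix(t) - gram_matrix(s)))
                   for t, s in zip(transfer_feats_list, style_feats_list))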

    Code

    1. Use VGG16 as the loss network:

    import numpy as np
    import tensorflow as tf

    relu = tf.nn.relu  # the code below calls relu() directly

    def conv_(inputs, w, b):
        # Convolution with stride 1, SAME padding, plus bias.
        return tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME") + b

    def max_pooling(inputs):
        # 2x2 max pooling with stride 2.
        return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")

    def vggnet(inputs, vgg_path='/home/luodan/project/fast_neural_style/Conditional-Instance-Norm-for-n-Style-Transfer/vgg_para/'):
        # RGB -> BGR, then subtract the ImageNet channel means expected by VGG.
        inputs = tf.reverse(inputs, [-1]) - np.array([103.939, 116.779, 123.68])
        # Pretrained VGG16 parameters: layer name -> [weights, biases].
        # allow_pickle=True is required by NumPy >= 1.16.3 to load this file.
        para = np.load(vgg_path + "vgg16.npy", encoding="latin1", allow_pickle=True).item()
        F = {}  # feature maps used by the content and style losses
        inputs = relu(conv_(inputs, para["conv1_1"][0], para["conv1_1"][1]))
        inputs = relu(conv_(inputs, para["conv1_2"][0], para["conv1_2"][1]))
        F["conv1_2"] = inputs
        inputs = max_pooling(inputs)
        inputs = relu(conv_(inputs, para["conv2_1"][0], para["conv2_1"][1]))
        inputs = relu(conv_(inputs, para["conv2_2"][0], para["conv2_2"][1]))
        F["conv2_2"] = inputs
        inputs = max_pooling(inputs)
        inputs = relu(conv_(inputs, para["conv3_1"][0], para["conv3_1"][1]))
        inputs = relu(conv_(inputs, para["conv3_2"][0], para["conv3_2"][1]))
        inputs = relu(conv_(inputs, para["conv3_3"][0], para["conv3_3"][1]))
        F["conv3_3"] = inputs
        inputs = max_pooling(inputs)
        inputs = relu(conv_(inputs, para["conv4_1"][0], para["conv4_1"][1]))
        inputs = relu(conv_(inputs, para["conv4_2"][0], para["conv4_2"][1]))
        inputs = relu(conv_(inputs, para["conv4_3"][0], para["conv4_3"][1]))
        F["conv4_3"] = inputs
        return F
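
    The returned dict F exposes one feature map per VGG block (conv1_2, conv2_2, conv3_3, conv4_3). A minimal usage sketch, assuming RGB image batches in [0, 255]; the split of layers between content and style is an assumption here, the post does not state it:

    images = tf.placeholder(tf.float32, [4, 256, 256, 3])  # batch of RGB images
    feats = vggnet(images)                                 # layer name -> feature map
    style_layers = ["conv1_2", "conv2_2", "conv3_3", "conv4_3"]
    content_layer = "conv3_3"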
    

    2. The transform network:

    def transfer(image):
        # Image transformation network: downsample, 5 residual blocks, upsample.
        conv1 = _conv_layer(image, 32, 9, 1)
        conv2 = _conv_layer(conv1, 64, 3, 2)
        conv3 = _conv_layer(conv2, 128, 3, 2)
        resid1 = _residual_block(conv3, 3)
        resid2 = _residual_block(resid1, 3)
        resid3 = _residual_block(resid2, 3)
        resid4 = _residual_block(resid3, 3)
        resid5 = _residual_block(resid4, 3)
        # Nearest-neighbor upsampling + convolution is used instead of the
        # transposed convolutions below, which tend to cause checkerboard artifacts:
        # conv_t1 = _conv_tranpose_layer(resid5, 64, 3, 2)
        # conv_t2 = _conv_tranpose_layer(conv_t1, 32, 3, 2)
        conv_up1 = upsampling(resid5, 64, 3)
        conv_up2 = upsampling(conv_up1, 32, 3)
        conv_up3 = _conv_layer(conv_up2, 3, 9, 1, relu=False)
        # Squash the output into [0, 255] so it can be fed to the VGG loss network.
        preds = tf.nn.sigmoid(conv_up3) * 255.
        return preds

    def _conv_layer(net, num_filters, filter_size, strides, relu=True):
        # Convolution + instance normalization (+ optional ReLU).
        weights_init = _conv_init_vars(net, num_filters, filter_size)
        strides_shape = [1, strides, strides, 1]
        net = tf.nn.conv2d(net, weights_init, strides_shape, padding='SAME')
        net = _instance_norm(net)
        if relu:
            net = tf.nn.relu(net)
        return net

    def upsampling(net, num_filters, filter_size):
        # Double the spatial resolution, then convolve with stride 1.
        net = tf.image.resize_nearest_neighbor(net, [tf.shape(net)[1] * 2, tf.shape(net)[2] * 2])
        weights_init = _conv_init_vars(net, num_filters, filter_size)
        net = tf.nn.conv2d(net, weights_init, [1, 1, 1, 1], padding='SAME')
        return _instance_norm(net)

    def _residual_block(net, filter_size=3):
        # Two 128-channel conv layers with a skip connection.
        tmp = _conv_layer(net, 128, filter_size, 1)
        return net + _conv_layer(tmp, 128, filter_size, 1, relu=False)

    def _instance_norm(net, train=True):
        # Normalize each sample over its spatial dimensions, per channel,
        # with a learned scale and shift (instance normalization).
        batch, rows, cols, channels = [i.value for i in net.get_shape()]
        var_shape = [channels]
        mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
        shift = tf.Variable(tf.zeros(var_shape))
        scale = tf.Variable(tf.ones(var_shape))
        epsilon = 1e-3
        normalized = (net - mu) / (sigma_sq + epsilon) ** (.5)
        return scale * normalized + shift

    def _conv_init_vars(net, out_channels, filter_size, transpose=False):
        # Create a weight variable for a (possibly transposed) convolution.
        _, rows, cols, in_channels = [i.value for i in net.get_shape()]
        if not transpose:
            weights_shape = [filter_size, filter_size, in_channels, out_channels]
        else:
            weights_shape = [filter_size, filter_size, out_channels, in_channels]
        weights_init = tf.Variable(tf.truncated_normal(weights_shape, stddev=0.02, seed=1), dtype=tf.float32)
        return weights_init
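
    Putting the pieces together: a training-loop sketch in the same TF1 style, reusing the loss helpers sketched in the "Principle" section. The values of content_weight, style_weight and the learning rate are hypothetical; the original post does not include its training code:

    # Hypothetical hyperparameters, not taken from the post.
    content_weight, style_weight, lr = 7.5, 100.0, 1e-3

    content_images = tf.placeholder(tf.float32, [4, 256, 256, 3])
    style_images = tf.placeholder(tf.float32, [4, 256, 256, 3])

    preds = transfer(content_images)   # transferred images, values in [0, 255]
    F_c = vggnet(content_images)       # content targets
    F_s = vggnet(style_images)         # style targets
    F_t = vggnet(preds)                # features of the transferred images

    layers = ["conv1_2", "conv2_2", "conv3_3", "conv4_3"]
    loss = content_weight * content_loss(F_t["conv3_3"], F_c["conv3_3"]) \
         + style_weight * style_loss([F_t[l] for l in layers],
                                     [F_s[l] for l in layers])

    # Only the transform network's variables receive gradients: the VGG
    # parameters are plain numpy arrays, so conv2d treats them as constants.
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)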
    

    References

    [1] Justin Johnson, Alexandre Alahi, Li Fei-Fei. Perceptual Losses for Real-Time Style Transfer and Super-Resolution. ECCV 2016.
