iOS实现类Prisma软件

作者: Jiao123 | 来源:发表于2017-04-25 18:11 被阅读980次

    前言


    Prisma在2016上线后就大火,该APP是利用神经网络人工智能技术,为普通照片加入艺术效果的照片编辑软件。

    同年Google也发布了一篇《A LEARNED REPRESENTATION FOR ARTISTIC STYLE》论文,实现了前向运算一次为照片整合多种艺术风格的功能,并且优化了内存使用和运算速度,可以在移动设备上快速运算。

    最近在研究Tensorflow整合iOS过程中,发现google公开了论文实现的源码和训练数据,也就是说我们可以通过自己写一个前向运算图,整合其训练参数就可以快速实现类Prisma的应用。

    下面就介绍一下如何在iPhone上跑一个自己的"Prisma"

    招财和咕噜

    准备工作


    1. 安装Tensorflow,这个官网上有详细教程这里就不多说了。
    2. 搭建iOS+Tensorflow工程,这个可以根据Git上的步骤实现,也可以参考官方的Demo程序配置。(这个过程有很多坑,多次尝试,应该可以配置成功)
    3. 下载模型,本次使用的模型是image_stylization,google已开源在GitHub上。
    4. 下载训练好的参数,Google提供了2个:
      Monet
      Varied
      Monet训练了10种艺术图片,Varied训练了32种。
      当然你也可以自己训练艺术图片,但是得下载VGG的训练参数和ImageNet数据,然后自己训练,比较花时间。

    构建计算图


    虽然Google提供了模型的源码,但是并没有在源码中输出运算图已方便迁移到移动设备中使用,Android的Demo中倒是提供了生成的pb,如何觉得自己写计算图麻烦可以直接拷到自己iOS工程中使用。

    我这里创建了一个python的工程,然后把Google源码中model.py相关的文件都加入了工程。
    我的建图代码如下:

    import numpy as np
    import tensorflow as tf
    import ast
    import os
    from tensorflow.python import pywrap_tensorflow
    
    from matplotlib import pyplot
    from matplotlib.pyplot import imshow
    
    import image_utils
    import model
    import ops
    import argparse
    import sys
    
    
    num_styles = 32
    imgWidth = 512
    imgHeight = 512
    channel = 3
    checkpoint = "/Users/Jiao/Desktop/TFProject/style-image/checkpoint/multistyle-pastiche-generator-varied.ckpt"
    
    inputImage = tf.placeholder(tf.float32,shape=[None,imgWidth,imgHeight,channel],name="input")
    styles = tf.placeholder(tf.float32,shape=[num_styles],name="style")
    
    with tf.name_scope(""):
        transform = model.transform(inputImage,
                                normalizer_fn=ops.weighted_instance_norm,
                                normalizer_params={
                                    # 'weights': tf.constant(mixture),
                                    'weights' : styles,
                                    'num_categories': num_styles,
                                    'center': True,
                                    'scale': True})
    
    model_saver = tf.train.Saver(tf.global_variables())
    
    with tf.Session() as sess:
        tf.train.write_graph(sess.graph_def, "/Users/Jiao/Desktop/TFProject/style-image/protobuf", "input.pb")
        #checkpoint = os.path.expanduser(checkpoint)
        #if tf.gfile.IsDirectory(checkpoint):
        #    checkpoint = tf.train.latest_checkpoint(checkpoint)
        #    tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))
        #model_saver.restore(sess, checkpoint)
    
        #newstyle = np.zeros([num_styles], dtype=np.float32)
        #newstyle[18] = 0.5
        #newstyle[17] = 0.5
        #newImage = np.zeros((1,imgWidth,imgHeight,channel))
        #style_image = transform.eval(feed_dict={inputImage:newImage,styles:newstyle})
        #style_image = style_image[0]
        #imshow(style_image)
        #pyplot.show()
    

    这里输入节点是inputstyle,输出节点是model中的transformer/expand/conv3/conv/Sigmoid

    到此就将模型的计算图保存到了本地文件夹中。
    接下来就是将图和ckpt中的参数合并,并且生成移动端的可以使用的pb文件,这一步可以参考我上一篇文章《iOS+Tensorflow实现图像识别》,很容易就实现。

    iOS工程


    在上面准备工作中,如果你已经按步骤搭建好iOS+TF的工程,这里你只需要导入生成的最终pb文件就行了。工程结构如图:

    XCode工程

    然后在iOS使用pb文件,我这里直接导入了Google提供的tensorflow_utils,使用这个类里面的LoadModel方法可以很快的生成含有计算图的session。

    - (void)viewDidLoad {
        [super viewDidLoad];
        tensorflow::Status load_status;
        load_status = LoadModel(@"rounded_graph", @"pb", &tf_session);
        if (!load_status.ok()) {
            LOG(FATAL) << "Couldn't load model: " << load_status;
        }
        currentStyle = 0;
        isDone = true;
        _styleImageView.layer.borderColor = [UIColor grayColor].CGColor;
        _styleImageView.layer.borderWidth = 0.5;
        _ogImageView.layer.borderColor = [UIColor grayColor].CGColor;
        _ogImageView.layer.borderWidth = 0.5;
    }
    

    最后就是获取图片,执行运算,生成艺术图片展示。这里图片需要转换成bitmap然后获取data值,展示图片也是相识的过程。具体代码如下:

    - (void)runCnn:(UIImage *)compressedImg
    {
        unsigned char *pixels = [self getImagePixel:compressedImg];
        int image_channels = 4;
        tensorflow::Tensor image_tensor(
                                        tensorflow::DT_FLOAT,
                                        tensorflow::TensorShape(
                                                                {1, wanted_input_height, wanted_input_width, wanted_input_channels}));
        auto image_tensor_mapped = image_tensor.tensor<float, 4>();
        tensorflow::uint8 *in = pixels;
        float *out = image_tensor_mapped.data();
        for (int y = 0; y < wanted_input_height; ++y) {
            float *out_row = out + (y * wanted_input_width * wanted_input_channels);
            for (int x = 0; x < wanted_input_width; ++x) {
                tensorflow::uint8 *in_pixel =
                in + (x * wanted_input_width * image_channels) + (y * image_channels);
                float *out_pixel = out_row + (x * wanted_input_channels);
                for (int c = 0; c < wanted_input_channels; ++c) {
                    out_pixel[c] = in_pixel[c];
                }
            }
        }
        
        
        tensorflow::Tensor style(tensorflow::DT_FLOAT, tensorflow::TensorShape({32}));
        float *style_data = style.tensor<float, 1>().data();
        memset(style_data, 0, sizeof(float) * 32);
        style_data[currentStyle] = 1;
        
        if (tf_session.get()) {
            std::vector<tensorflow::Tensor> outputs;
            tensorflow::Status run_status = tf_session->Run(
                                                            {{contentNode, image_tensor},
                                                                {styleNode, style}},
                                                            {outputNode},
                                                            {},
                                                            &outputs);
            if (!run_status.ok()) {
                LOG(ERROR) << "Running model failed:" << run_status;
                isDone = true;
                free(pixels);
            } else {
                float *styledData = outputs[0].tensor<float,4>().data();
                UIImage *styledImg = [self createImage:styledData];
                dispatch_async(dispatch_get_main_queue(), ^{
                    _styleImageView.image = styledImg;
                    dispatch_after(dispatch_time(DISPATCH_TIME_NOW, (int64_t)(0.3 * NSEC_PER_SEC)), dispatch_get_main_queue(), ^{
                        isDone = true;
                        free(pixels);
                    });
                });
            }
        }
    }
    
    - (unsigned char *)getImagePixel:(UIImage *)image
    {
        int width = image.size.width;
        int height = image.size.height;
        CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
        unsigned char *rawData = (unsigned char*) calloc(height * width * 4, sizeof(unsigned char));
        NSUInteger bytesPerPixel = 4;
        NSUInteger bytesPerRow = bytesPerPixel * width;
        NSUInteger bitsPerComponent = 8;
        CGContextRef context = CGBitmapContextCreate(rawData, width, height,
                                                     
                                                     bitsPerComponent, bytesPerRow, colorSpace,
                                                     
                                                     kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
        
        CGColorSpaceRelease(colorSpace);
        CGContextDrawImage(context, CGRectMake(0, 0, width, height), image.CGImage);
        UIImage *ogImg = [UIImage imageWithCGImage:CGBitmapContextCreateImage(context)];
        dispatch_async(dispatch_get_main_queue(), ^{
            _ogImageView.image = ogImg;
        });
        CGContextRelease(context);
        return rawData;
    }
    
    - (UIImage *)createImage:(float *)pixels
    {
        unsigned char *rawData = (unsigned char*) calloc(wanted_input_height * wanted_input_width * 4, sizeof(unsigned char));
        for (int y = 0; y < wanted_input_height; ++y) {
            unsigned char *out_row = rawData + (y * wanted_input_width * 4);
            for (int x = 0; x < wanted_input_width; ++x) {
                float *in_pixel =
                pixels + (x * wanted_input_width * 3) + (y * 3);
                unsigned char *out_pixel = out_row + (x * 4);
                for (int c = 0; c < wanted_input_channels; ++c) {
                    out_pixel[c] = in_pixel[c] * 255;
                }
                out_pixel[3] = UINT8_MAX;
            }
        }
        CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
        NSUInteger bytesPerPixel = 4;
        NSUInteger bytesPerRow = bytesPerPixel * wanted_input_width;
        NSUInteger bitsPerComponent = 8;
        CGContextRef context = CGBitmapContextCreate(rawData, wanted_input_width, wanted_input_height,
                                                     
                                                     bitsPerComponent, bytesPerRow, colorSpace,
                                                     
                                                     kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
        
        CGColorSpaceRelease(colorSpace);
        UIImage *retImg = [UIImage imageWithCGImage:CGBitmapContextCreateImage(context)];
        CGContextRelease(context);
        free(rawData);
        return retImg;
    }
    

    这里说明一下,前面python工程已经定义了,我的输入和输出图片的大小是512✕512。

    连接iPhone,运行工程_


    最后连上手机运行,就可以自己创建自己的艺术类图片了。😊

    放几张运行效果图:


    截图1 截图2 截图3

    相关文章

      网友评论

      • ff0f6b05526c:正在学习,遇到了一些困难,不知道是否能发一份源码参考一下,谢谢!wxhsui@qq.com
        Jiao123:@若风者 已发
      • wingsmm:wingsmm@me.com 谢谢大牛
      • summer_shi:感谢大牛,源码能发一份参考吗?shiwanjun@126.com,谢谢!
        summer_shi:@Jiao123 好的,已经收到邮件,谢谢@Jiao123
        Jiao123:@summer_shi 已发送
      • 徐涌盛:作者大大有github链接么
        Jiao123:@徐涌盛 没有放github上
      • symbian_0:最近在研究这个,希望作者看到能给我发一份demo
        taowei728@126.com
        summer_shi:@symbian_0 能转发一份给我吗? shiwanjun@126.com,谢谢!
        symbian_0:@Jiao123 谢谢
        Jiao123:@symbian_0 已发送
      • 只因我为足球而生:大神,跪求个demo可以吗?
        Jiao123:@只因我为足球而生 已发送
        只因我为足球而生:@Jiao123 215742216@qq.com
        Jiao123:@只因我为足球而生 可以,邮件发给你
      • 毛尖尖:大神,跪求demo monazh@126.com
        毛尖尖:@Jiao123 :kissing_heart:谢谢
        Jiao123:@张3_ 已发
      • a8453e425101:求大牛发一份代码,edcrfvedcrfv2@sina.com
        a8453e425101:@Jiao123 我使用下载的ckpt生成的pb文件使用也有问题,帖主有做什么特殊处理吗?上面的英文的解决方式没太看懂:“There are some Nans in the ckpt you downloaded from the wed...”
        a8453e425101:@Jiao123 非常感谢
        Jiao123:@edcrfvedcrfv2 已发
      • 不想重复造轮子:大牛也可以发一份demo给我吗 547172058@qq.com
        Jiao123:@死心 那个只是个转化工具,模型都得自己写。
        不想重复造轮子:@Jiao123 tensorflow_utils 这个文件是根据不同类别的模型来的吗? 还是说有方法能根据模型来生成调用的c方法的文件 ?
        Jiao123:@死心 已发
      • rogerwu1228:你好, 同学, 能分享一份源码吗, 邮箱: why_404@126.com 谢谢啦 :smile:
        rogerwu1228:@Jiao123 Thx...
        Jiao123:已发送
      • 0a932f82c896:你好, 能分享一下代码吗。 邮箱:71902816@qq.com
        rogerwu1228:你好, 同学, 能分享一份源码吗, 邮箱: why_404@126.com 谢谢啦 :smile:
        0a932f82c896:@Jiao123 感谢:pray:
        Jiao123:已发
      • 0408d074cf8e:膜拜大神,麻烦能发一份本工程的源码吗?我的邮箱是liyangsh48@gmail.com,提前谢过!!
        0408d074cf8e:@Jiao123 非常感谢
        Jiao123:已发
      • K_Gopher:感谢大神,可否共享您的demo源码,我的邮箱是454694347@qq.com
        K_Gopher:@Jiao123 thx
        Jiao123:已发
      • binshadow:大神,求个demo。谢谢分享
        8929317@qq.com
        Jiao123:已发
      • 木木不哭:求个demo 1010774511@qq.com ,谢啦
        Jiao123: 已发至邮箱
      • xieyingze:求demo,谢谢楼主13424485402@163.com
        Jiao123:已发至邮箱
      • 土匪小勇:赞一个,最近公司项目需要用到这块。 希望作者看到能联系我哈下,感谢
        david@gpower.co 手机 18991902916
        Jiao123:好的
      • ifelsego:非常感谢分享。按照文章做了一下,但是在执行Run的时候会crash掉。可以分享一下你的demo吗?
        0408d074cf8e:@ifelsego 同学,能给我的邮箱发一份吗?liyangsh48@gmail.com,谢谢!
        ifelsego:@Jiao123 多谢多谢!
        Jiao123:已发至你邮箱
      • fc4ca578e197:您好,谢谢您的分享。我按照您的方法转ckpt的模型为pb,然后应用在Android上,但是并不能使用,这是我的issue https://github.com/tensorflow/tensorflow/issues/9678
        请问您可以分享您demo和转换后的model吗。谢谢。
        rogerwu1228:@Jiao123 博主, 我也遇到了楼上的问题, 生成的pb文件, 不能正常使用,然后我看到在github上说到: "There are some Nans in the ckpt you downloaded from the wed.
        You should get rid of them. Then you will get right outputs on the mobile device." 这个具体是对 ckpt 进行什么操作呢...
        Jiao123:可以
      • c0279c31eb49:大神, 求个demo
        Jiao123:可以私信我邮箱,我分享一份给你。
      • Bert夏伤:感谢大牛,你的demo源码可以下载吗?
        Jiao123:@Bert夏伤 已发邮箱
        Bert夏伤:@Jiao123 zq529972839@qq.com 谢谢啦!:+1:
        Jiao123:代码暂时没托管,你可以私信我你的邮箱,然后我发给你

      本文标题:iOS实现类Prisma软件

      本文链接:https://www.haomeiwen.com/subject/rygvzttx.html