TensorFlow Android调用

作者: Jcme丶Ls | 来源:发表于2017-07-06 10:58 被阅读470次

    前言

    当我们把使用Python训练的模型固化成PB文件之后,再进行相应的模型压缩之后可以考虑往Mobile端移植了,本文主要讲解TensorFlow Model移植到Android端。

    TensorFlow1.0之后推出了Java版本,所以间接为Android开发TensorFlow程序带来便利,以前我们需要用JNI去编写,可是JNI难于调试,C++代码对于普通Android开发者来讲还是比Java繁琐,所以本文以Java API讲述开发过程。

    正文

    下面就正式开始一直TensorFlow model到Android中啦。

    • 引入依赖

    在TensorFlow更新到1.2.0版本之后,TensorFlow为广大开发者提供了gradle依赖,现在我们想要引入TensorFlow只需要在gradle中加入

    compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
    

    即可引入TensorFlow的库。

    • 复制PB文件

    快速开发的话直接把PB文件放在assets文件夹里就行,如果正式上线的时候觉得PB文件一起打包较大的话可以放在服务器,打开APP的时候提示下载再复制进去就好。

    • 创建TensorFlowInterface类

    这个类指的是我们读取、识别等一系列方法存放的类,名字随你取。

    • 载入TensorFlow

    在类的第一行加入这句话,会在加载类的时候首先加载TensorFlow

        {
            System.loadLibrary("tensorflow_inference");
        }
    
    • 定义常量

    在这一步,我们先定义一些常量,比如输入节点名、输出节点名、输出图像的尺寸、通道、输入节点数据类型、输出节点数据类型。代码如下

        private static final String input_layer = "inputs/X";
        private static final String output_layer = "output/predict";
    
        private Context context;
        private static final int HEIGHT = 64;
        private static final int WIDTH = 256;
        private static final int CHANNEL = 1;
    
        private float[] inputs = new float[HEIGHT*WIDTH*CHANNEL];
        private long[] outputs = new long[11];
    
    • 初始化模型

    这一步TensorFlow的模型会载入到内存中,传入assets和PB文件名

    TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(),"rounded_graph.pb"); 
    
    • 喂数据给输入节点

    这里的参数是输入节点名,输入数据,输入数据的shape

    inferenceInterface.feed(input_layer,inputs,1,16384);
    
    • run session
    inferenceInterface.run(new String[] { output_layer }, false);
    
    • 获取输出数据

    根据你在Python定义的输出格式,new一个接收输出数据的变量,从输出节点获取数据

    byte[] outPuts = new byte[88];
    inferenceInterface.fetch(output_layer,outPuts);
    
    • 数据变换

    从输出节点获取到数据之后就需要你对自己的输出数据进行操作,比如我在我们model里最终输出的结果进行了Argmax的操作,Argmax返回的值类型是Int64的,在Android里只有long对应,但fetch方法的接受变量的参数类型只有double、float、int、byte,所以这里需要使用byte获取,再进行转换。这里跟传统的byte[8]转long有些不同,具体处理方式要看你定义的数据格式,我这里的byte[8]用网上的方法转long发现数值非常大,于是遍历一遍byte[8],发现每个子元素都是相同的数值,所以这里只取第一个元素,组成一个新的数组,再对这个数组进行解析。

    long[] tOutputs=new long[11];
    for (int i=0;i<11;i++)
    {
        int k=i*8;
        tOutputs[i]=outPuts[k];
        Log.i("output",tOutputs[i]+"");
    }
    String outputStr="";
    for(int i=0;i<11;i++){
        long char_idx=tOutputs[i];
        long char_code = 0;
        if (char_idx<10){
            char_code = char_idx + (int)('0');
        }
        else if (char_idx<36){
            char_code = char_idx-10 + (int)('A');
        }
        else if (char_idx<62){
            char_code = char_idx + (int)('a');
        }
        outputStr+= (char)char_code;
    }
    

    后记

    有Java API确实相比C++来的更直观方便,而且native debug也比JNI好操作,等TensorFlowLite出来的时候,Android TensorFlow应用会更加广泛吧。

    相关文章

      网友评论

        本文标题:TensorFlow Android调用

        本文链接:https://www.haomeiwen.com/subject/udochxtx.html