美文网首页Android常用功能
Android——Tensorflow-Lite简单使用

Android——Tensorflow-Lite简单使用

作者: 海晨忆 | 来源:发表于2019-05-24 11:41 被阅读1次

    个人博客:haichenyi.com。感谢关注

      项目里面用到了tflite,用于做简单的图片处理,不是判断图片是什么类型,就是传进去图片,生成新图片,类似于前面一篇讲的GPUImage的滤镜功能,但是比滤镜功能更加强大。

      我这里要做的就是集成,拿人家训练好的模型直接来用,我不用去训练模型。

    第一步 依赖

    //依赖库
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
    
    
    android {
        ···
        //set no compress models
        aaptOptions {
            noCompress "tflite"
        }
    }
    

    第二步 加载训练模型

      网上很多介绍资料都是把训练模型直接copy到项目main目录下的assets目录(不存在就创建)与java目录平级,自然,这样的加载方式就是

    // load infer model
        private void loadModel(String model) {
            try {
                tflite = new Interpreter(loadModelFile(model));
                Log.d(TAG, model + " model load success");
                tflite.setNumThreads(4);
                load_result = true;
            } catch (IOException e) {
                Log.d(TAG, model + " model load fail");
                load_result = false;
                e.printStackTrace();
            }
        }
        
        
        /**
         * Memory-map the model file in Assets.
         */
        private MappedByteBuffer loadModelFile(String model) throws IOException {
            AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
            FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
            FileChannel fileChannel = inputStream.getChannel();
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    

      一个tflite文件就好几M,甚至十几M,全部copy到项目里面不显示,所以,我们一般项目里面用都是先下载,然后再使用,那,这样的方式,我们要怎么加载训练模型呢?

      我们先分析一下再assets目录下面怎么加载的?说白了就是新建一个Interpreter对象,就是加载模型。上面的方法都过时了,我们可以找到Interpreter类,里面你会看到如下的方法

    //第一个参数传tflite文件,第二个参数传一个Interpreter静态内部类对象
    public Interpreter(@NonNull File modelFile, Interpreter.Options options) {
            this.wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
    }
        
    //所以,我们自己项目里面加载模型,用如下方式即可
    Interpreter.Options options = new Interpreter.Options();
    options.setNumThreads(4);
    tflite = new Interpreter(new File(""), options);
    

    第三步 执行run方法

    tflite.run(in, out);
    

      通过执行这个run方法,获取我们需要的东西,第一个参数,输入对象,第二个参数,输出参数。

    重点,敲黑板

    重点,敲黑板

    重点,敲黑板

      重点就在这里,这里的输入和输出参数要怎么传?我这里训练模型是用Python做的,它需要传入一个四维数组,所以,输出我们自然也要用一个四维数组接收。

      这里的四维数组怎么传递呐?就要说到Android里面的bitmap知识了,它的每个像素点都是一个ARGB数组。即透明度,红色,绿色,蓝色。我们前面的灰色滤镜之类的东西,实际上就是改变RGB三原色的值,让颜色变成灰色,然后改变亮度之类的就是改变每个管道的透明度。网上有很多这样的知识。

      再来说说这个四维数组,我项目里面用到的这个四维数组:1 X 256 X 256 X 3,这几个值怎么理解呢?

    1:表示一张图片
    
    256X256:表示图片的宽高
    
    3:表示RGB色值
    

      那我们怎么把bitmap对象,转换成我们需要的四维数组呐?

    //定义了一个一维数组,里面就是我们需要的参数,便于修改
    private int[] ddims = {1, 256, 256, 3};
    
        /**
         * 获取图片的四维数组
         * @param bitmap bitmap对象
         * @param ddims 参数数组
         * @return 图片四维数组
         */
    public float[][][][] getScaledMatrix(Bitmap bitmap, int[] ddims) {
            //新建一个1*256*256*3的四维数组
            float[][][][] inFloat = new float[ddims[0]][ddims[1]][ddims[2]][ddims[3]];
            //新建一个一维数组,长度是图片像素点的数量
            int[] pixels = new int[ddims[1] * ddims[2]];
            //把原图缩放成我们需要的图片大小
            Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[1], ddims[2], false);
            //把图片的每个像素点的值放到我们前面新建的一维数组中
            bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[1], ddims[2]);
            int pixel = 0;
            //for循环,把每个像素点的值转换成RBG的值,存放到我们的目标数组中
            for (int i = 0; i < ddims[1]; ++i) {
                for (int j = 0; j < ddims[2]; ++j) {
                    final int val = pixels[pixel++];
                    float red = ((val >> 16) & 0xFF);
                    float green = ((val >> 8) & 0xFF);
                    float blue = (val & 0xFF);
                    float[] arr = {red, green, blue};
                    inFloat[0][i][j] = arr;
                }
            }
            if (bm.isRecycled()) {
                bm.recycle();
            }
            return inFloat;
        }
    

      上面代码注释写的很清楚了吧?每一行都有注释,for循环的作用也标的很清楚,通过这个方法,我们得到的就是我们想要的四维数组了,这里的四维数组的格式,图片的大小,都是tflite文件建模型的时候设置好的,看你们训练模型的工程师是怎么定义的,你就怎么传。

      然后,新建一个一模一样格式的数组去接收输出值,也是一个四维数组,那么,我们怎么把这个四维数组转换成我们需要的bitmap呢?

    //创建bitmap的方法,
    Bitmap.createBitmap(@NonNull @ColorInt int[] colors,
                int width, int height, Config config);
    

      就是这个方法,传一个一维颜色数组,图片的宽高,还有一个图片的格式,那我们这里就是要把这个四维数组转成一个一维的颜色数组了。

        /**
         * 四维数组转成bitmap对象
         * @param outArr 数组
         * @param ddims 格式
         * @return bitmap
         */
        public Bitmap getBitmap(float[][][][] outArr, int[] ddims) {
            //获取图片的三维数组
            float[][][] temp = outArr[0];
            int n = 0;
            //新建一个接收的颜色数组,长度就是图片的宽高之积,类似于上面的像素那个数组
            int[] colorArr = new int[ddims[1] * ddims[2]];
            //for循环遍历把图片的ARGB色值转成一个颜色值,放入颜色数组中
            for (int i = 0; i < ddims[1]; i++) {
                for (int j = 0; j < ddims[2]; j++) {
                    float[] arr = temp[i][j];
                    int alpha = 255;
                    int red = (int) arr[0];
                    int green = (int) arr[1];
                    int blue = (int) arr[2];
                    int tempARGB = (alpha << 24) | (red << 16) | (green << 8) | blue;
                    colorArr[n++] = tempARGB;
                }
            }
            //创建bitmap对象
            return Bitmap.createBitmap(colorArr, ddims[1], ddims[2], Bitmap.Config.ARGB_8888);
        }
    

      至此,我们就拿到了,我们需要的bitmap对象了,然后再做后续的逻辑即可。

    项目链接

    相关文章

      网友评论

        本文标题:Android——Tensorflow-Lite简单使用

        本文链接:https://www.haomeiwen.com/subject/byfgzqtx.html