美文网首页
Android集成TensorFlow使用Mnist数据集实现手

Android集成TensorFlow使用Mnist数据集实现手写数字识别

作者: 放羊娃华振 | 来源:发表于2020-05-17 13:08 被阅读0次

概述

最近想学习一点AI相关的东西,所以就简单实现了一个手写数字识别的项目。虽然其中很多东西还不是太明白,还需要自己不断地探索,这里先把目前所学记录下来。

Android端的实现

1、集成TensorFlow

网上很多集成TensorFlow的方法很复杂,需要编译源码,其实没有那么复杂。当然你也可以按照那些步骤把源码下载下来进行编译集成;我使用的是简单的集成方式,在Android工程的build.gradle中引入依赖就行,代码如下:

    implementation 'org.tensorflow:tensorflow-android:+'
//  implementation 'org.tensorflow:tensorflow-android:1.13.1'
2、导入跨平台的模型pb文件

这里涉及模型的训练,这个相对来说还是比较复杂的,涉及到文件大小的优化和识别精准度的问题,我目前也训练出了几个模型但是精准度还是没有达到我的预期,但是刚开始学习还是勉强够用了。


image.png
3、实现手写数字控件

这个就是自定义一个控件,在画布上书写数字,再拿到带有数字的bitmap对象。

package com.stormdzh.tfmnist.handwrite;

import android.content.Context;
import android.content.res.Resources;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Paint;
import android.graphics.Path;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;

import com.stormdzh.tfmnist.R;

/**
 * Custom view on which the user hand-writes a digit with a finger.
 *
 * <p>Finished strokes are committed to an off-screen bitmap; the stroke that is still in
 * progress lives only in a {@link Path} and is painted on top of that bitmap in
 * {@link #onDraw(Canvas)}. {@link #getBitMap()} exposes the off-screen bitmap so that the
 * caller can feed it to the recognizer.
 *
 * @author dzh
 */
public class MyPaintView extends View {
    private Resources myResources;

    // Paint used for the strokes, and paint used to blit the off-screen bitmap.
    private Paint myPaint;
    private Paint mBitmapPaint;

    // Stroke currently being drawn (not yet committed to the bitmap).
    private Path myPath;

    // Off-screen bitmap holding all committed strokes, and the canvas that draws into it.
    // Both stay null until the view has received a non-empty size.
    private Bitmap myBitmap;
    private Canvas myCanvas;

    // Last touch position that was added to the path.
    private float mX, mY;
    // Minimum finger travel (in pixels) before a new path segment is added.
    private static final float TOUCH_TOLERANCE = 4;

    // Last size reported by onSizeChanged.
    private int mWidth;
    private int mHeight;

    // Background colour of the off-screen bitmap, resolved once in initialize().
    private int backgroundColor;

    public MyPaintView(Context context) {
        super(context);
        initialize();
    }

    public MyPaintView(Context context, AttributeSet attrs, int defStyle) {
        super(context, attrs, defStyle);
        initialize();
    }

    public MyPaintView(Context context, AttributeSet attrs) {
        super(context, attrs);
        initialize();
    }

    /**
     * Creates the paints and the path shared by all drawing operations.
     */
    private void initialize() {
        myResources = getResources();
        backgroundColor = myResources.getColor(R.color.purple_dark);

        // Paint for the free-hand strokes: thick, round, anti-aliased white line.
        myPaint = new Paint();
        myPaint.setAntiAlias(true);
        myPaint.setDither(true);
        myPaint.setColor(myResources.getColor(R.color.white));
        myPaint.setStyle(Paint.Style.STROKE);
        myPaint.setStrokeJoin(Paint.Join.ROUND);
        myPaint.setStrokeCap(Paint.Cap.ROUND);
        myPaint.setStrokeWidth(88);

        myPath = new Path();

        mBitmapPaint = new Paint(Paint.DITHER_FLAG);
    }

    /**
     * (Re)creates the off-screen bitmap at the new view size. Any previous drawing is
     * discarded.
     */
    @Override
    protected void onSizeChanged(int w, int h, int oldw, int oldh) {
        super.onSizeChanged(w, h, oldw, oldh);
        mWidth = w;
        mHeight = h;
        if (w <= 0 || h <= 0) {
            // Bitmap.createBitmap rejects non-positive dimensions; keep whatever we had.
            return;
        }
        myBitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
        myBitmap.eraseColor(backgroundColor);
        myCanvas = new Canvas(myBitmap);
    }

    /**
     * Feeds touch events into the current stroke.
     *
     * @return always {@code true}; the view consumes every touch event
     */
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        float x = event.getX();
        float y = event.getY();

        switch (event.getActionMasked()) {
            case MotionEvent.ACTION_DOWN:
                touch_start(x, y);
                invalidate();
                break;
            case MotionEvent.ACTION_MOVE:
                touch_move(x, y);
                invalidate();
                break;
            case MotionEvent.ACTION_UP:
                touch_up();
                invalidate();
                break;
            case MotionEvent.ACTION_CANCEL:
                // The gesture was taken away from us: drop the unfinished stroke so the
                // screen does not show something that is missing from the bitmap.
                myPath.reset();
                invalidate();
                break;
            default:
                break;
        }
        return true;
    }

    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);

        // Paint the committed strokes first; without this they would vanish on redraw.
        if (myBitmap != null) {
            canvas.drawBitmap(myBitmap, 0, 0, mBitmapPaint);
        }

        // Then the stroke that is still in progress.
        canvas.drawPath(myPath, myPaint);
    }

    /** Starts a new stroke at the given position. */
    private void touch_start(float x, float y) {
        myPath.reset();
        myPath.moveTo(x, y);
        mX = x;
        mY = y;
    }

    /** Extends the stroke with a smoothed segment once the finger moved far enough. */
    private void touch_move(float x, float y) {
        float dx = Math.abs(x - mX);
        float dy = Math.abs(y - mY);
        if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
            // Quadratic curve through the midpoint gives a smoother line than lineTo.
            myPath.quadTo(mX, mY, (x + mX) / 2, (y + mY) / 2);
            mX = x;
            mY = y;
        }
    }

    /** Finishes the stroke and commits it to the off-screen bitmap. */
    private void touch_up() {
        myPath.lineTo(mX, mY);
        // Commit the path to the off-screen bitmap; otherwise the stroke would disappear
        // as soon as the path is reset below.
        if (myCanvas != null) {
            myCanvas.drawPath(myPath, myPaint);
        }
        // Reset so the stroke is not drawn twice.
        myPath.reset();
    }

    /**
     * Erases the whole drawing: fills the off-screen bitmap with the background colour,
     * drops any unfinished stroke and requests a redraw. Safe to call before the view has
     * been laid out.
     */
    public void clear() {
        if (myBitmap != null) {
            myBitmap.eraseColor(backgroundColor);
        }
        myPath.reset();
        invalidate();
    }

    /**
     * Returns the off-screen bitmap holding all committed strokes.
     *
     * @return the live backing bitmap (not a copy), or {@code null} if the view has not
     *         received a non-empty size yet
     */
    public Bitmap getBitMap() {
        return myBitmap;
    }
}
4、实现布局的编写
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a test button, a preview image, the recognition result text,
     the hand-writing area and a row with the clear / recognize buttons. -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:background="#80300900"
    android:gravity="center_horizontal"
    android:orientation="vertical"
    android:paddingLeft="16dp"
    android:paddingTop="16dp"
    android:paddingRight="16dp"
    android:paddingBottom="16dp">

    <!-- Hint text shown above the test button. -->
    <TextView
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:text="点击下面按钮可以实现测试不同数字" />

    <!-- Cycles through the bundled sample digit images (handled in MainActivity). -->
    <Button
        android:id="@+id/btnTest"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="测试" />

    <!-- Preview of the bitmap that is being recognized. The id is spelled with a
         double "w"; MainActivity looks it up under exactly this name. -->
    <ImageView
        android:id="@+id/imgPrevieww"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center" />

    <!-- Recognition result text. -->
    <TextView
        android:id="@+id/tvResult"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:gravity="center"
        android:text="未知" />

    <!-- Hand-writing area (custom view). -->
    <com.stormdzh.tfmnist.handwrite.MyPaintView
        android:id="@+id/mMyPaintView"
        android:layout_width="320dp"
        android:layout_height="320dp"
        android:layout_marginTop="10dp"
        android:background="#000000" />

    <!-- Button row: clear the canvas / run recognition on the drawing. -->
    <LinearLayout
        android:layout_marginTop="10dp"
        android:gravity="center_horizontal"
        android:layout_width="match_parent"
        android:layout_height="40dp"
        android:orientation="horizontal">

        <Button
            android:id="@+id/btnClear"
            android:layout_width="120dp"
            android:layout_height="40dp"
            android:text="清空" />

        <Button
            android:id="@+id/btnOk"
            android:layout_width="120dp"
            android:layout_height="40dp"
            android:text="识别" />
    </LinearLayout>
</LinearLayout>
5、实现一个预测的工具类,调用加载模型和实现预测基本方法
package com.stormdzh.tfmnist;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

/**
 * Thin wrapper around {@link TensorFlowInferenceInterface} that loads a frozen MNIST
 * model and runs a prediction on a bitmap.
 *
 * <p>Which input node is fed depends on {@code MainActivity.isRegression}.
 */
public class PredictionTF {
    private static final String TAG = "PredictionTF";

    // Dimensions of the model's input / output tensors.
    private static final int IN_COL = 1;
    private static final int IN_ROW = 28 * 28;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;

    // Names of the graph nodes that are fed and fetched.
    private static final String INPUT_NAME_REGRESSION = "regression/Placeholder";
    private static final String INPUT_NAME_CONVOLUTIONAL = "convolutional/x";
    private static final String KEEP_PROB_NAME = "convolutional/keep_prob";
    private static final String OUTPUT_NAME = "output";

    // Dropout keep probability fed at inference time. It must be 1.0 so that no unit is
    // dropped; a value below 1 makes the prediction random and less accurate.
    private static final float INFERENCE_KEEP_PROB = 1.0f;

    private TensorFlowInferenceInterface inferenceInterface;

    /**
     * Loads the model.
     *
     * @param assetManager asset manager used to open the model file
     * @param modePath     path of the frozen graph ({@code .pb}) file
     */
    PredictionTF(AssetManager assetManager, String modePath) {
        inferenceInterface = new TensorFlowInferenceInterface(assetManager, modePath);
        Log.e(TAG, "模型文件加载成功");
    }

    /**
     * Runs the loaded model on a bitmap.
     *
     * @param bitmap image to classify; it is scaled to 28x28 before being fed
     * @return the contents of the model's output node as an int array
     */
    public int[] getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap, 28, 28);

        // Read the flag once so that the input node and the keep_prob feed always agree.
        final boolean regression = MainActivity.isRegression;
        final String inputName = regression ? INPUT_NAME_REGRESSION : INPUT_NAME_CONVOLUTIONAL;

        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
        if (!regression) {
            // Only the convolutional graph is fed a dropout placeholder.
            inferenceInterface.feed(KEEP_PROB_NAME, new float[]{INFERENCE_KEEP_PROB});
        }

        inferenceInterface.run(new String[]{OUTPUT_NAME});

        int[] outputs = new int[OUT_COL * OUT_ROW];
        inferenceInterface.fetch(OUTPUT_NAME, outputs);
        return outputs;
    }

    /**
     * Scales a bitmap and converts it into a row-major float array with every value
     * normalised to the range 0..1.
     *
     * <p>Only the red channel of each pixel is used, i.e. the image is treated as
     * grayscale. NOTE(review): bitmaps whose channels differ (for example a coloured
     * background) are therefore not converted to true luminance - confirm this matches
     * what the model was trained on.
     *
     * @param bitmap source image
     * @param rx     target width in pixels (28 for MNIST)
     * @param ry     target height in pixels (28 for MNIST)
     * @return array of {@code width * height} values of the scaled image, row by row
     */
    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry) {
        int height = bitmap.getHeight();
        int width = bitmap.getWidth();

        // Scale factors that map the source size onto rx x ry.
        float scaleWidth = ((float) rx) / width;
        float scaleHeight = ((float) ry) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
        Log.i(TAG, "bitmap width:" + bitmap.getWidth() + ",height:" + bitmap.getHeight());
        Log.i(TAG, "bitmap.getConfig():" + bitmap.getConfig());

        height = bitmap.getHeight();
        width = bitmap.getWidth();

        // One bulk read instead of width*height getPixel() calls; the pixels arrive in
        // row-major order, which is exactly the order the result needs.
        int[] pixels = new int[height * width];
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height);

        float[] result = new float[pixels.length];
        for (int k = 0; k < pixels.length; k++) {
            result[k] = Color.red(pixels[k]) / 255.0f;
        }
        return result;
    }
}
6、在MainActivity中加载布局和调用预测工具类
package com.stormdzh.tfmnist;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import com.stormdzh.tfmnist.handwrite.MyPaintView;

/**
 * Main screen: lets the user cycle through bundled sample digit images or hand-write a
 * digit, and shows what the TensorFlow model recognizes.
 */
public class MainActivity extends AppCompatActivity implements View.OnClickListener {

    /** {@code true} selects the linear (regression) model, {@code false} the convolutional one. */
    public static boolean isRegression = false;
    private static final String TAG = "MainActivity";

    /** Sample images for the digits 0..9, indexed by digit. */
    private static final int[] TEST_DIGIT_DRAWABLES = {
            R.drawable.n0, R.drawable.n1, R.drawable.n2, R.drawable.n3, R.drawable.n4,
            R.drawable.n5, R.drawable.n6, R.drawable.n7, R.drawable.n8, R.drawable.n9
    };

    // Path of the model file; overwritten in onCreate according to isRegression.
    private String MODEL_FILE = "file:///android_asset/mnist_convolutional.pb";
    private TextView tvResult;
    private ImageView imgPrevieww;
    private Bitmap bitmap;
    private PredictionTF preTF;
    // Index of the sample image shown by the next click on the test button.
    private int index = 0;
    private MyPaintView mMyPaintView;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        findViewById(R.id.btnTest).setOnClickListener(this);
        findViewById(R.id.btnClear).setOnClickListener(this);
        findViewById(R.id.btnOk).setOnClickListener(this);
        tvResult = findViewById(R.id.tvResult);
        imgPrevieww = findViewById(R.id.imgPrevieww);
        mMyPaintView = findViewById(R.id.mMyPaintView);

        // Show the first sample image right away.
        getBitmap();

        if (isRegression) {
            MODEL_FILE = "file:///android_asset/mnist_regression.pb";
        } else {
            MODEL_FILE = "file:///android_asset/mnist_convolutional.pb";
        }
        // Load the TensorFlow model from the assets.
        preTF = new PredictionTF(getAssets(), MODEL_FILE);
    }

    /**
     * Decodes the sample image selected by {@code index}, stores it in {@code bitmap} and
     * shows it in the preview. If {@code index} is outside 0..9 the previously decoded
     * bitmap is kept.
     *
     * @return the current sample bitmap; may be {@code null} if decoding failed
     */
    private Bitmap getBitmap() {
        if (index >= 0 && index < TEST_DIGIT_DRAWABLES.length) {
            bitmap = BitmapFactory.decodeResource(getResources(), TEST_DIGIT_DRAWABLES[index]);
        }
        imgPrevieww.setImageBitmap(bitmap);
        return bitmap;
    }

    @Override
    public void onClick(View view) {
        // if/else instead of switch: resource IDs are not guaranteed to be compile-time
        // constants (they are non-final by default with recent Android Gradle plugins).
        final int id = view.getId();
        if (id == R.id.btnTest) {
            if (index >= TEST_DIGIT_DRAWABLES.length) {
                index = 0;
            }
            bitmap = getBitmap();
            index++;
            if (bitmap == null) {
                // decodeResource returns null when the image cannot be decoded.
                Log.w(TAG, "sample bitmap could not be decoded");
                return;
            }
            Log.i(TAG, "sourceBitmap=>" + bitmap.getWidth() + " :" + bitmap.getHeight());
            recogBitmap(bitmap);
        } else if (id == R.id.btnClear) {
            mMyPaintView.clear();
        } else if (id == R.id.btnOk) {
            Bitmap viewBitmap = mMyPaintView.getBitMap();
            if (viewBitmap == null) {
                // The paint view has no backing bitmap until it has been laid out.
                Log.w(TAG, "paint view has no bitmap yet");
                return;
            }
            Bitmap finalBitmap = scaledBitmap(viewBitmap);
            imgPrevieww.setImageBitmap(finalBitmap);
            Log.i(TAG, "finalBitmap=>" + finalBitmap.getWidth() + " :" + finalBitmap.getHeight());
            recogBitmap(finalBitmap);
        }
    }

    /**
     * Scales the hand-written bitmap uniformly so that it becomes 74 pixels wide.
     * NOTE(review): the reason for the value 74 is not visible here - confirm.
     *
     * @param viewBitmap bitmap taken from the paint view
     * @return the scaled bitmap
     */
    private Bitmap scaledBitmap(Bitmap viewBitmap) {
        int width = viewBitmap.getWidth();
        float scale = 74f / width;
        Matrix matrix = new Matrix();
        matrix.setScale(scale, scale);
        return Bitmap.createBitmap(viewBitmap, 0, 0, viewBitmap.getWidth(),
                viewBitmap.getHeight(), matrix, true);
    }

    /**
     * Runs the model on the bitmap and shows the result in the result text view.
     *
     * @param bitmap image to recognize
     */
    private void recogBitmap(Bitmap bitmap) {
        StringBuilder res = new StringBuilder("图片识别结果为:");
        int[] result = preTF.getPredict(bitmap);
        for (int i = 0; i < result.length; i++) {
            Log.i(TAG, res.toString() + result[i]);
            res.append(result[i]).append(" ");
        }
        tvResult.setText(res.toString());
    }

    /**
     * Renders a view into a bitmap through the view's drawing cache. Not called from
     * within this class.
     * NOTE(review): the drawing-cache API is deprecated on newer Android versions -
     * verify before using this method.
     *
     * @param view view to capture
     * @return the drawing-cache bitmap of the view
     */
    public Bitmap convertViewToBitmap(View view) {
        view.measure(View.MeasureSpec.makeMeasureSpec(0, View.MeasureSpec.UNSPECIFIED),
                View.MeasureSpec.makeMeasureSpec(0, View.MeasureSpec.UNSPECIFIED));
        view.layout(0, 0, view.getMeasuredWidth(), view.getMeasuredHeight());
        view.buildDrawingCache();
        Bitmap bitmap = view.getDrawingCache();
        return bitmap;
    }
}

效果

首先看下工程运行后的界面:

image.png
点击测试按钮可以依次循环测试我添加的10张0-9的数字图片,这些数字图片的识别率是100%。
黑色区域是手写区域,有清空和识别两个按钮,清空是清空画布,识别就是开始预测。
例如手写“4”的识别结果:
image.png
目前demo中是使用卷积模型识别的,数字写歪了等异常情况下识别会有误,这个以后还需要继续优化。代码可以参考我的github工程:TFMnist

相关文章

网友评论

      本文标题:Android集成TensorFlow使用Mnist数据集实现手

      本文链接:https://www.haomeiwen.com/subject/egojohtx.html