概述
最近想学习一点AI相关的东西,所以就简单实现了一个手写数字识别的项目,虽然其中很多东西还不是太明白,还需要自己不断地探索,这里就把目前所学先记录下来。
Android端的实现
1、集成TensorFlow
网上很多集成TensorFlow的方法很复杂,需要编译源码,其实没有那么复杂。当然你也可以按照那些步骤把源码下载下来进行编译集成,我使用的是简单的集成方式,在Android工程下引入依赖就行,代码如下:
// Pin an explicit version: a dynamic '+' version resolves to whatever is newest at
// build time, which makes builds non-reproducible.
implementation 'org.tensorflow:tensorflow-android:1.13.1'
2、导入跨平台的模型pb文件
这里涉及模型的训练,这个相对来说还是比较复杂的,涉及到文件大小的优化和识别精准度的问题,我目前也训练出了几个模型但是精准度还是没有达到我的预期,但是刚开始学习还是勉强够用了。

3、实现手写数字控件
这个就是自定义一个控件,在画布上书写数字,再拿到带有数字的bitmap对象。
package com.stormdzh.tfmnist.handwrite;
import android.content.Context;
import android.content.res.Resources;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Paint;
import android.graphics.Path;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;
import com.stormdzh.tfmnist.R;
/**
 * Custom view on which the user hand-writes a digit with a finger.
 *
 * <p>The stroke currently being drawn lives in a {@link Path}; when the finger is lifted
 * the stroke is committed to an off-screen {@link Bitmap}, which callers can obtain via
 * {@link #getBitMap()}.
 *
 * @author dzh
 * Created: 2020-05-15 22:43
 */
public class MyPaintView extends View {
    /** Minimum movement (in touch-event coordinate units) before the stroke is extended. */
    private static final float TOUCH_TOLERANCE = 4;

    private Resources myResources;
    // Paint used for the free-hand stroke.
    private Paint myPaint;
    // Paint used when blitting the off-screen bitmap onto the view.
    private Paint mBitmapPaint;
    // The stroke that is currently in progress.
    private Path myPath;
    // Off-screen canvas and its backing bitmap; both are null until the view has a size.
    private Bitmap myBitmap;
    private Canvas myCanvas;
    // Last touch position that was added to the path.
    private float mX, mY;
    // Most recent width and height reported by onSizeChanged.
    private int mWidth;
    private int mHeight;

    public MyPaintView(Context context) {
        super(context);
        initialize();
    }

    public MyPaintView(Context context, AttributeSet attrs, int defStyle) {
        super(context, attrs, defStyle);
        initialize();
    }

    public MyPaintView(Context context, AttributeSet attrs) {
        super(context, attrs);
        initialize();
    }

    /**
     * Creates the paints and the path. Shared by all constructors.
     */
    private void initialize() {
        myResources = getResources();
        // Wide, round, white stroke for free-hand drawing.
        myPaint = new Paint();
        myPaint.setAntiAlias(true);
        myPaint.setDither(true);
        myPaint.setColor(myResources.getColor(R.color.white));
        myPaint.setStyle(Paint.Style.STROKE);
        myPaint.setStrokeJoin(Paint.Join.ROUND);
        myPaint.setStrokeCap(Paint.Cap.ROUND);
        myPaint.setStrokeWidth(88);
        myPath = new Path();
        mBitmapPaint = new Paint(Paint.DITHER_FLAG);
    }

    /**
     * Allocates a fresh off-screen bitmap matching the new size, filled with the
     * background color. Any previous drawing is discarded.
     */
    @Override
    protected void onSizeChanged(int w, int h, int oldw, int oldh) {
        super.onSizeChanged(w, h, oldw, oldh);
        mWidth = w;
        mHeight = h;
        if (w <= 0 || h <= 0) {
            // Bitmap.createBitmap throws on non-positive dimensions; keep the current state.
            return;
        }
        myBitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
        myBitmap.eraseColor(myResources.getColor(R.color.purple_dark));
        myCanvas = new Canvas(myBitmap);
    }

    /**
     * Tracks the finger: starts a stroke on down, extends it on move and commits it
     * to the off-screen bitmap on up or cancel.
     *
     * @return always {@code true}; the view consumes every touch event
     */
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        float x = event.getX();
        float y = event.getY();
        switch (event.getActionMasked()) {
            case MotionEvent.ACTION_DOWN:
                touch_start(x, y);
                invalidate();
                break;
            case MotionEvent.ACTION_MOVE:
                touch_move(x, y);
                invalidate();
                break;
            case MotionEvent.ACTION_UP:
            case MotionEvent.ACTION_CANCEL:
                // A cancelled gesture is committed as well; otherwise the stroke stays
                // visible only until the next touch resets the path, then disappears.
                touch_up();
                invalidate();
                break;
            default:
                break;
        }
        return true;
    }

    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        // Draw the committed strokes first; without this, finished strokes would vanish.
        if (myBitmap != null) {
            canvas.drawBitmap(myBitmap, 0, 0, mBitmapPaint);
        }
        // Then draw the stroke that is still in progress.
        canvas.drawPath(myPath, myPaint);
    }

    /** Begins a new stroke at the given position. */
    private void touch_start(float x, float y) {
        myPath.reset();
        myPath.moveTo(x, y);
        mX = x;
        mY = y;
    }

    /** Extends the stroke with a quadratic segment once the finger moved far enough. */
    private void touch_move(float x, float y) {
        float dx = Math.abs(x - mX);
        float dy = Math.abs(y - mY);
        if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
            myPath.quadTo(mX, mY, (x + mX) / 2, (y + mY) / 2);
            mX = x;
            mY = y;
        }
    }

    /** Finishes the stroke and commits it to the off-screen bitmap. */
    private void touch_up() {
        myPath.lineTo(mX, mY);
        // Commit the path to the off-screen canvas; without this the stroke is lost
        // when the path is reset below.
        if (myCanvas != null) {
            myCanvas.drawPath(myPath, myPaint);
        }
        // Reset so the stroke is not drawn twice.
        myPath.reset();
    }

    /**
     * Clears the whole drawing by refilling the bitmap with the background color.
     * Safe to call before the view has been laid out.
     */
    public void clear() {
        if (myBitmap != null) {
            myBitmap.eraseColor(myResources.getColor(R.color.purple_dark));
        }
        // Drop any stroke in progress and redraw.
        myPath.reset();
        invalidate();
    }

    /**
     * Returns the off-screen bitmap holding all committed strokes.
     *
     * @return the backing bitmap (not a copy), or {@code null} if the view has not
     *         received a non-zero size yet
     */
    public Bitmap getBitMap() {
        return myBitmap;
    }
}
4、实现布局的编写
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a vertical stack with a test button, a preview image, the result text,
     the handwriting area and a row with the clear / recognize buttons. -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#80300900"
android:gravity="center_horizontal"
android:orientation="vertical"
android:paddingLeft="16dp"
android:paddingTop="16dp"
android:paddingRight="16dp"
android:paddingBottom="16dp">
<!-- Hint text for the test button below. -->
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="center"
android:text="点击下面按钮可以实现测试不同数字" />
<!-- Steps through the bundled sample digit images and recognizes each one. -->
<Button
android:id="@+id/btnTest"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="测试" />
<!-- Preview of the bitmap that was handed to the model. -->
<ImageView
android:id="@+id/imgPrevieww"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="center" />
<!-- Shows the recognition result. -->
<TextView
android:id="@+id/tvResult"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center"
android:text="未知" />
<!-- Handwriting area (custom view). -->
<com.stormdzh.tfmnist.handwrite.MyPaintView
android:id="@+id/mMyPaintView"
android:layout_width="320dp"
android:layout_height="320dp"
android:layout_marginTop="10dp"
android:background="#000000" />
<!-- Button row: clear the canvas / run recognition on the drawing. -->
<LinearLayout
android:layout_marginTop="10dp"
android:gravity="center_horizontal"
android:layout_width="match_parent"
android:layout_height="40dp"
android:orientation="horizontal">
<Button
android:id="@+id/btnClear"
android:layout_width="120dp"
android:layout_height="40dp"
android:text="清空" />
<Button
android:id="@+id/btnOk"
android:layout_width="120dp"
android:layout_height="40dp"
android:text="识别" />
</LinearLayout>
</LinearLayout>
5、实现一个预测的工具类,调用加载模型和实现预测基本方法
package com.stormdzh.tfmnist;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
 * Thin wrapper around {@link TensorFlowInferenceInterface} that loads a frozen model
 * and runs digit prediction on a bitmap.
 */
public class PredictionTF {
    private static final String TAG = "PredictionTF";
    // Dimensions fed to the input node and fetched from the output node.
    private static final int IN_COL = 1;
    private static final int IN_ROW = 28 * 28;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;
    // Input node names of the two supported models.
    private static final String INPUT_CONVOLUTIONAL = "convolutional/x";
    private static final String INPUT_REGRESSION = "regression/Placeholder";
    // Dropout keep-probability node of the convolutional model.
    private static final String KEEP_PROB_NODE = "convolutional/keep_prob";
    // Output node name.
    private static final String outputName = "output";

    private TensorFlowInferenceInterface inferenceInterface;

    /**
     * Loads the model.
     *
     * @param assetManager asset manager used to open the model file
     * @param modePath     path of the model file
     */
    PredictionTF(AssetManager assetManager, String modePath) {
        inferenceInterface = new TensorFlowInferenceInterface(assetManager, modePath);
        // Success is informational, not an error.
        Log.i(TAG, "模型文件加载成功");
    }

    /**
     * Runs the loaded model on a bitmap. Which input node is fed depends on
     * {@code MainActivity.isRegression}.
     *
     * @param bitmap the image to classify; it is scaled to 28x28 before being fed
     * @return the values fetched from the output node as an int array
     */
    public int[] getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap, 28, 28);
        boolean regression = MainActivity.isRegression;
        // Chosen per call instead of mutating shared static state.
        String inputName = regression ? INPUT_REGRESSION : INPUT_CONVOLUTIONAL;
        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
        if (!regression) {
            // Dropout must be disabled at inference time: keep_prob = 1.0 keeps every
            // unit. Feeding 0.5 randomly drops activations and makes predictions noisy.
            inferenceInterface.feed(KEEP_PROB_NODE, new float[]{1.0f});
        }
        inferenceInterface.run(new String[]{outputName});
        int[] outputs = new int[OUT_COL * OUT_ROW];
        inferenceInterface.fetch(outputName, outputs);
        return outputs;
    }

    /**
     * Scales a bitmap to the requested size and converts it to a row-major float array
     * with every value normalized to the range 0..1.
     *
     * <p>Only the red channel is read, so the result is a true grayscale value only
     * when the image is gray (r == g == b).
     *
     * @param bitmap the source image
     * @param rx     target width in pixels (columns), e.g. 28
     * @param ry     target height in pixels (rows), e.g. 28
     * @return array of length {@code rx * ry}
     */
    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry) {
        // createScaledBitmap guarantees the exact target size that the feed dimensions require.
        Bitmap scaled = Bitmap.createScaledBitmap(bitmap, rx, ry, true);
        Log.i(TAG, "bitmap width:" + scaled.getWidth() + ",height:" + scaled.getHeight());
        Log.i(TAG, "bitmap.getConfig():" + scaled.getConfig());
        int[] pixels = new int[rx * ry];
        // Bulk read, row by row (stride == width).
        scaled.getPixels(pixels, 0, rx, 0, 0, rx, ry);
        float[] result = new float[pixels.length];
        for (int k = 0; k < pixels.length; k++) {
            result[k] = Color.red(pixels[k]) / 255.0f;
        }
        // createScaledBitmap may return the source itself; only free our own copy.
        if (scaled != bitmap) {
            scaled.recycle();
        }
        return result;
    }
}
6、在MainActivity中加载布局和调用预测工具类
package com.stormdzh.tfmnist;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;
import androidx.appcompat.app.AppCompatActivity;
import com.stormdzh.tfmnist.handwrite.MyPaintView;
/**
 * Single screen of the demo: lets the user step through bundled sample digits or
 * hand-write a digit, and shows what the model recognizes.
 */
public class MainActivity extends AppCompatActivity implements View.OnClickListener {
    /** {@code true} selects the linear (regression) model, {@code false} the convolutional one. */
    public static boolean isRegression = false;
    private static final String TAG = "MainActivity";
    /** Bundled sample images, indexed by the digit they show. */
    private static final int[] SAMPLE_DIGITS = {
            R.drawable.n0, R.drawable.n1, R.drawable.n2, R.drawable.n3, R.drawable.n4,
            R.drawable.n5, R.drawable.n6, R.drawable.n7, R.drawable.n8, R.drawable.n9
    };
    /** Location of the model file; reassigned in onCreate according to isRegression. */
    private String MODEL_FILE = "file:///android_asset/mnist_convolutional.pb";
    private TextView tvResult;
    private ImageView imgPrevieww;
    private Bitmap bitmap;
    private PredictionTF preTF;
    /** Index of the sample digit that the test button shows next. */
    private int index = 0;
    private MyPaintView mMyPaintView;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        findViewById(R.id.btnTest).setOnClickListener(this);
        findViewById(R.id.btnClear).setOnClickListener(this);
        findViewById(R.id.btnOk).setOnClickListener(this);

        tvResult = (TextView) findViewById(R.id.tvResult);
        imgPrevieww = (ImageView) findViewById(R.id.imgPrevieww);
        mMyPaintView = findViewById(R.id.mMyPaintView);

        // Show the first sample right away.
        getBitmap();

        MODEL_FILE = isRegression
                ? "file:///android_asset/mnist_regression.pb"
                : "file:///android_asset/mnist_convolutional.pb";
        // Load the TensorFlow model from the chosen path.
        preTF = new PredictionTF(getAssets(), MODEL_FILE);
    }

    /**
     * Decodes the sample image for the current index, shows it in the preview and
     * returns it. An index outside 0..9 leaves the previously decoded bitmap in place.
     */
    private Bitmap getBitmap() {
        if (index >= 0 && index < SAMPLE_DIGITS.length) {
            bitmap = BitmapFactory.decodeResource(getResources(), SAMPLE_DIGITS[index]);
        }
        imgPrevieww.setImageBitmap(bitmap);
        return bitmap;
    }

    @Override
    public void onClick(View view) {
        final int id = view.getId();
        if (id == R.id.btnTest) {
            // Wrap around after the last sample.
            if (index > 9) {
                index = 0;
            }
            bitmap = getBitmap();
            Log.i(TAG, "sourceBitmap=>" + bitmap.getWidth() + " :" + bitmap.getHeight());
            index++;
            recogBitmap(bitmap);
        } else if (id == R.id.btnClear) {
            mMyPaintView.clear();
        } else if (id == R.id.btnOk) {
            // Take the drawing, shrink it, preview it and recognize it.
            Bitmap drawn = mMyPaintView.getBitMap();
            Bitmap finalBitmap = scaledBitmap(drawn);
            imgPrevieww.setImageBitmap(finalBitmap);
            Log.i(TAG, "finalBitmap=>" + finalBitmap.getWidth() + " :" + finalBitmap.getHeight());
            recogBitmap(finalBitmap);
        }
    }

    /**
     * Returns a uniformly scaled version of the bitmap whose width is scaled to 74 pixels.
     */
    private Bitmap scaledBitmap(Bitmap viewBitmap) {
        final int srcWidth = viewBitmap.getWidth();
        final int srcHeight = viewBitmap.getHeight();
        final float factor = 74f / srcWidth;
        Matrix shrink = new Matrix();
        shrink.setScale(factor, factor);
        return Bitmap.createBitmap(viewBitmap, 0, 0, srcWidth, srcHeight, shrink, true);
    }

    /**
     * Runs the model on the bitmap, logs each returned value and shows them all in
     * the result text view.
     */
    private void recogBitmap(Bitmap bitmap) {
        StringBuilder text = new StringBuilder("图片识别结果为:");
        for (int value : preTF.getPredict(bitmap)) {
            Log.i(TAG, text.toString() + value);
            text.append(value).append(" ");
        }
        tvResult.setText(text.toString());
    }

    /**
     * Measures and lays out the view with unspecified constraints, builds its drawing
     * cache and returns that cache.
     */
    public Bitmap convertViewToBitmap(View view) {
        final int unspecified = View.MeasureSpec.makeMeasureSpec(0, View.MeasureSpec.UNSPECIFIED);
        view.measure(unspecified, unspecified);
        view.layout(0, 0, view.getMeasuredWidth(), view.getMeasuredHeight());
        view.buildDrawingCache();
        return view.getDrawingCache();
    }
}
效果
首先看下工程运行后的界面:

点击测试按钮可以依次循环测试我添加的10种0-9的数字,这些数字的识别率是100%。
黑色区域是手写区域,有清空和识别两个按钮,清空是清空画布,识别就是开始预测。
例如手写“4”的识别结果:

目前demo中是使用卷积模型识别的,有些数字写歪了等异常情况下识别会有误,这个以后还需要继续优化。代码可以参考我的github工程:TFMnist
网友评论