美文网首页
Python+Android进行TensorFlow开发

Python+Android进行TensorFlow开发

作者: 温驭臣 | 来源:发表于2018-12-03 13:54 被阅读0次

    Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。

    Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些操作符运算),也可以自己通过bazel编译生成这两个文件。

    将libandroid_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jniLibs目录下对应的ABI文件夹下。目录结构如下:

    Android目录结构

    同时在app的build.gradle中的dependencies模块下添加如下配置:

    dependencies {

    ...

    compile files('libs/libandroid_tensorflow_inference_java.jar')

    ...

    }

    使用tensorflow框架进行机器学习分为四个步骤:

    构造神经网络

    训练神经网络模型

    将训练好的模型输出为pb文件

    在Android上加载pb模型进行计算

    前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程:

    # -*-coding:utf-8 -*-

    from__future__importprint_function

    importos

    importtensorflowastf

    fromnumpy.randomimportRandomState

    os.environ['TF_CPP_MIN_LOG_LEVEL'] ='2'

    """

    训练模型

    """

    deftrain():

    # 定义训练数据集batch大小为8

    batch_size =8

    # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层

    w1 = tf.Variable(tf.random_normal([2,3], stddev=1, seed=1), name="w1_val")

    w2 = tf.Variable(tf.random_normal([3,1], stddev=1, seed=1), name="w2_val")

    # 定义输入输出格式

    x = tf.placeholder(tf.float32, shape=(None,2), name='x_input')

    y_ = tf.placeholder(tf.float32, shape=(None,1))

    # 定义神经网络前向传播过程

    a = tf.matmul(x, w1)

    y = tf.matmul(a, w2, name="cal_node")

    # 定义交叉熵和反向传播算法

    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y,1e-10,1.0)))

    train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)

    # 生成随机训练集

    rdm = RandomState(1)

    dataset_size =128

    # 定义映射关系

    X = rdm.rand(dataset_size,2)

    Y = [[int(x1 + x2 <1)]for(x1, x2)inX]

    withtf.Session()assess:

    # 初始化所有参数

    init_op = tf.global_variables_initializer()

    sess.run(init_op)

    # print sess.run(w1)

    # print sess.run(w2)

    STEPS =500

    foriinrange(STEPS):

    start = (i * batch_size) % dataset_size

    end = min(start + batch_size, dataset_size)

    # 训练神经网络,更新神经网络参数

    sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})

    ifi %100==0:

    total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})

    print("After %d training step(s), cross entropy on all data is %g"% (i, total_cross_entropy))

    print(sess.run(w1))

    print(sess.run(w2))

    # 保存check point

    saver = tf.train.Saver(tf.trainable_variables())

    saver.save(sess,'./model/checpt')

    上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图:

    checkpoint相关文件

    由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在android上只需要这个pb文件即可使用这个训练好的模型:

    """

    存储pb模型

    """

    defdump_graph_to_pb(pb_path):

    withtf.Session()assess:

    check_point = tf.train.get_checkpoint_state("./model/")

    ifcheck_point:

    saver = tf.train.import_meta_graph(check_point.model_checkpoint_path +'.meta')

    saver.restore(sess, check_point.model_checkpoint_path)

    else:

    raiseValueError("Model load failed from {}".format(check_point.model_checkpoint_path))

    graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(),"cal_node".split(","))

    withtf.gfile.GFile(pb_path,"wb")asf:

    f.write(graph_def.SerializeToString())

    拿到生成的pb模型,我们可以在android上使用了。将pb文件在这main/assets下:

    接下来就可以载入pb,进行计算了:

    publicclassMainActivityextendsAppCompatActivity{

    privateGraph graph_;

    privateSession session_;

    privateAssetManager assetManager;

    privatestaticExecutorService executorService;

    privatestaticHandler handler;

    @Override

    protectedvoidonCreate(Bundle savedInstanceState){

    super.onCreate(savedInstanceState);

    setContentView(R.layout.activity_main);

    executorService = Executors.newFixedThreadPool(5);

    // 初始化tensorflow

    initTensorFlow("outmodel.pb");

    // 使用tensorflow进行计算

    runTensorFlow();

    }

    ...

    }

    通过如下方式载入pb模型,初始化tensorflow:

    privateboolean initTensorFlow(String modelFile) {

    assetManager = getAssets();

    // 新建Graph

    graph_ = new Graph();

    InputStreamis=null;

    try{

    // 读取Assets pb文件

    is= assetManager.open(modelFile);

    }catch(IOException e) {

    e.printStackTrace();

    returnfalse;

    }

    try{

    // 加载pb到Graph

    TensorUtil.loadGraph(is, graph_);

    is.close();

    }catch(IOException e) {

    e.printStackTrace();

    returnfalse;

    }

    // 初始化session

    session_ = new Session(graph_);

    if(session_ ==null) {

    returnfalse;

    }

    returntrue;

    }

    然后就可以使用tensorflow API进行运算了:

    private void runTensorFlow() {

    executorService.execute(generatePredictRunnable(handler));

    }

    private Runnable generatePredictRunnable(Handler handler) {

    return new Runnable() {

    @Override

    public void run() {

    float[][] input = new float[1][2];

    input[0][0] = 1;

    input[0][1] = 2;

    // 定义输入tensor

    Tensor inputTensor = Tensor.create(input);

    // 指定输入,输出节点,运行并得到结果

    Tensor resultTensor = session_.runner()

    .feed("x_input", inputTensor)

    .fetch("cal_node")

    .run()

    .get(0);

    float[][] dst = new float[1][1];

    resultTensor.copyTo(dst);

    // 处理结果

    ArrayListresultList = new ArrayList<>();

    for (float val : dst[0]) {

    if (val != 0) {

    resultList.add(val);

    } else {

    break;

    }

    }

    }

    };

    }

    上面就是通过python训练机器学习模型,并在android平台进行调用的完整流程。

    相关文章

      网友评论

          本文标题:Python+Android进行TensorFlow开发

          本文链接:https://www.haomeiwen.com/subject/zgdycqtx.html