[iOS/Android] Deploying TensorFlow Lite on Mobile

Author: ItchyHiker | Published: 2018-12-15 17:29

    Notes on deploying your own deep learning model on iOS with TensorFlow Lite, following the example on the official TensorFlow Lite site. Android may be added later.

    Environment setup

    Install TensorFlow into your Python environment with pip:

    pip3 install tensorflow
    

    Clone the project from GitHub:

    git clone https://github.com/googlecodelabs/tensorflow-for-poets-2
    

    Download the dataset:

    wget http://download.tensorflow.org/example_images/flower_photos.tgz
    

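    The dataset contains five kinds of flowers; we use it to train a model that classifies the flower type.
    After downloading, extract it under tensorflow-for-poets-2/tf_files/: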

    [Image: Screen Shot 2018-12-15 at 5.04.17 PM.png]
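A quick way to confirm that the archive was extracted to the expected place is to count the images in each class directory. The following is a minimal sketch; the helper name `count_images_per_class` is hypothetical, and the path `tf_files/flower_photos` assumes the script is run from the repository root.

```python
import os


def count_images_per_class(image_dir):
    """Map each class sub-directory of image_dir to its number of files."""
    counts = {}
    for name in sorted(os.listdir(image_dir)):
        class_dir = os.path.join(image_dir, name)
        if not os.path.isdir(class_dir):
            continue  # skip stray top-level files (for example a license file)
        files = [f for f in os.listdir(class_dir) if not f.startswith(".")]
        counts[name] = len(files)
    return counts


# Assumed location of the extracted dataset, relative to the repository root.
image_dir = "tf_files/flower_photos"
if os.path.isdir(image_dir):
    for label, n in count_images_per_class(image_dir).items():
        print(label, n)
```

If the directory is missing or a class shows zero images, the archive was probably extracted somewhere else.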

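    Model training

    The scripts directory contains several scripts. Among them, retrain.py takes an Inception or MobileNet model that TensorFlow has already trained on ImageNet (downloaded automatically at run time) and retrains it for our flower classification task. It also exposes many configurable parameters, for example: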

      --architecture ARCHITECTURE
                            Which model architecture to use. 'inception_v3' is the
                            most accurate, but also the slowest. For faster or
                            smaller models, chose a MobileNet with the form
                            'mobilenet_<parameter size>_<input_size>[_quantized]'.
                            For example, 'mobilenet_1.0_224' will pick a model
                            that is 17 MB in size and takes 224 pixel input
                            images, while 'mobilenet_0.25_128_quantized' will
                            choose a much less accurate, but smaller and faster
                            network that's 920 KB on disk and takes 128x128
                            images. See
                            https://research.googleblog.com/2017/06/mobilenets-
                            open-source-models-for.html for more information on
                            Mobilenet.
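The naming scheme in the help text can be unpacked mechanically. The sketch below is a hypothetical helper written for illustration, not the parsing code used by retrain.py; it only assumes the `mobilenet_<parameter size>_<input_size>[_quantized]` form quoted above and the 299-pixel input of Inception v3.

```python
import re

# Hypothetical helper: this is not how retrain.py itself parses the flag.
_MOBILENET = re.compile(r"^mobilenet_(\d+\.\d+)_(\d+)(_quantized)?$")


def parse_architecture(name):
    """Split an --architecture value into its parts."""
    if name == "inception_v3":
        return {"family": "inception_v3", "width": None,
                "input_size": 299, "quantized": False}
    match = _MOBILENET.match(name)
    if match is None:
        raise ValueError("unrecognized architecture: %r" % name)
    width, size, quantized = match.groups()
    return {"family": "mobilenet", "width": float(width),
            "input_size": int(size), "quantized": quantized is not None}


print(parse_architecture("mobilenet_1.0_224"))
print(parse_architecture("mobilenet_0.25_128_quantized"))
```

The `input_size` value is the one that has to match the image size passed to the converter later on.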
    

    Training command:

    python scripts/retrain.py \
    --output_graph=tf_files/retrained_graph.pb \
    --output_labels=tf_files/retrained_labels.txt \
    --image_dir=tf_files/flower_photos \
    --architecture=mobilenet_1.0_224  \
    --summaries_dir=tf_files/training_summaries/mobilenet_1.0_224
    
    [Image: Screen Shot 2018-12-15 at 5.16.05 PM.png]

    Open TensorBoard to watch how loss and accuracy change during fine-tuning:

    tensorboard --logdir=tf_files/training_summaries/mobilenet_1.0_224
    
    [Image: Screen Shot 2018-12-15 at 5.22.48 PM.png]

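    Model conversion

    To convert the trained frozen graph file into a tflite model we use toco, the conversion tool provided by Google. For an introduction to toco, see my other article on TensorFlow mobile model conversion (Tensorflow移动端模型转换).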

    IMAGE_SIZE=224
    toco \
      --graph_def_file=tf_files/retrained_graph.pb \
      --output_file=tf_files/optimized_graph.lite \
      --output_format=TFLITE \
      --input_shape=1,${IMAGE_SIZE},${IMAGE_SIZE},3 \
      --input_array=input \
      --output_array=final_result \
      --inference_type=FLOAT \
      --inference_input_type=FLOAT 
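When converting several retrained graphs with different input sizes, it can help to assemble the same command from Python. This is a minimal sketch that only rebuilds the flags shown above; the function name `build_toco_command` is hypothetical, and the default tensor names `input` and `final_result` are the ones used in the command above.

```python
def build_toco_command(graph_def_file, output_file, image_size,
                       input_array="input", output_array="final_result"):
    """Return the toco invocation above as an argument list."""
    # Batch size 1, square RGB image: 1,<size>,<size>,3
    input_shape = ",".join(str(d) for d in (1, image_size, image_size, 3))
    return [
        "toco",
        "--graph_def_file=" + graph_def_file,
        "--output_file=" + output_file,
        "--output_format=TFLITE",
        "--input_shape=" + input_shape,
        "--input_array=" + input_array,
        "--output_array=" + output_array,
        "--inference_type=FLOAT",
        "--inference_input_type=FLOAT",
    ]


cmd = build_toco_command("tf_files/retrained_graph.pb",
                         "tf_files/optimized_graph.lite", 224)
print(" ".join(cmd))
```

The resulting list can be handed to `subprocess.run(cmd, check=True)` once toco is on the PATH.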
    

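    Measuring the accuracy of the tflite model

    Converting the model can cost some accuracy, so once we have the converted tflite model we want to measure how accurate it is. This means calling the tflite interpreter directly from a Python script and computing the model's accuracy on a test dataset.
    Below is a script that does this (TensorFlow's API changes quickly, so it is not guaranteed to work):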

    import numpy as np
    import tensorflow as tf
    from skimage.transform import resize
    import cv2
    import os
    
    def predict(interpreter, input_details, output_details, img):
        """Return the index of the highest-scoring class for one image.

        interpreter: the TFLite interpreter that runs the model
        input_details / output_details: tensor metadata from the interpreter
        img: image as an H x W x 3 RGB array
        """
        input_shape = input_details[0]['shape']
        # skimage's resize() returns floats in [0, 1] for uint8 input; adjust
        # the scaling below to the range the model was trained with.
        input_data = resize(img, input_shape[1:])
        input_data = input_data.reshape(input_shape)
        input_data = input_data.astype("float32")
        # input_data = (input_data - 127.5) / 127.5
        interpreter.set_tensor(input_details[0]['index'], input_data)
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        return np.argmax(output_data)
    
    
    if __name__ == "__main__":
        # Load the TFLite model written by toco and allocate tensors.
        # (From TensorFlow 1.13 on, the class lives at tf.lite.Interpreter.)
        interpreter = tf.contrib.lite.Interpreter(model_path="tf_files/optimized_graph.lite")
        interpreter.allocate_tensors()
    
        # Get input and output tensor metadata.
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
    
        # Each sub-directory is one class. The script assumes that the sorted
        # directory order matches the order of the model's output labels.
        data_path = "tf_files/flower_photos"
        sub_classes = [f for f in sorted(os.listdir(data_path))
                       if os.path.isdir(os.path.join(data_path, f))]
        print(sub_classes)
        count = 0
        total = 0
        for label, sub_class in enumerate(sub_classes):
            print("processing: ", sub_class)
            sub_path = os.path.join(data_path, sub_class)
            img_files = [f for f in os.listdir(sub_path) if not f.startswith('.')]
            for img_file in img_files:
                img = cv2.imread(os.path.join(sub_path, img_file), cv2.IMREAD_COLOR)
                if img is None:  # skip files OpenCV cannot decode
                    continue
                img = img[:, :, ::-1]  # OpenCV loads BGR; convert to RGB
                pred = predict(interpreter, input_details, output_details, img)
                if pred == label:
                    count += 1
                total += 1
        print('accuracy:', count / total)
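The script above reports one overall accuracy figure. When the converted model loses accuracy, a per-class breakdown often shows where. The following is a small sketch under the assumption that the true labels and predictions have been collected into two lists of class indices; the helper name `per_class_accuracy` and the sample values are made up for illustration.

```python
from collections import Counter


def per_class_accuracy(labels, predictions, class_names):
    """Return {class name: fraction of that class predicted correctly}."""
    total = Counter(labels)
    correct = Counter(l for l, p in zip(labels, predictions) if l == p)
    return {class_names[i]: correct[i] / total[i] for i in sorted(total)}


# Illustrative values only, not results from the model in this article.
names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
labels = [0, 0, 1, 2, 2, 2]
predictions = [0, 1, 1, 2, 2, 4]
for name, acc in per_class_accuracy(labels, predictions, names).items():
    print(name, round(acc, 3))
```

Running the same breakdown on the original TensorFlow model and on the tflite model makes it easier to tell a preprocessing mismatch (all classes drop) from a conversion issue that affects only some classes.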
    

    Calling the tflite model from an iOS project

    First install the required tools:

    xcode-select --install
    sudo gem install cocoapods
    pod install --project-directory=ios/tflite/
    

    Open the iOS project:

    open ios/tflite/tflite_camera_example.xcworkspace
    

    Copy the model file and the label file to the corresponding paths in the project:

    cp tf_files/optimized_graph.lite ios/tflite/data/graph.lite
    cp tf_files/retrained_labels.txt ios/tflite/data/labels.txt
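The app pairs each model output with a line of the label file, so it is worth checking the file before copying it. The sketch below assumes the file holds one label per line in the same order as the model's outputs; the helper names `load_labels` and `top_k` are hypothetical.

```python
def load_labels(path):
    """Read one label per line, ignoring blank lines."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def top_k(scores, labels, k=3):
    """Return the k (label, score) pairs with the highest scores."""
    if len(scores) != len(labels):
        raise ValueError("model outputs %d scores but there are %d labels"
                         % (len(scores), len(labels)))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels[i], scores[i]) for i in order[:k]]


# Illustrative scores only; a real run would use the interpreter's output.
print(top_k([0.1, 0.7, 0.2], ["daisy", "roses", "tulips"], k=2))
```

A length mismatch between the label list and the model's output vector is a sign that the label file belongs to a different training run.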
    

    Connect a phone and run the app directly. The result reproduced on the phone:

    [Image: IMG_0014.PNG]

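    --------- Results from converting and calling my own trained model, building on the official tutorial, will be added later ---------

    Issue log

    1. toco fails when converting the original simplenet.pb to tflite.
      Original model structure: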
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_1 (InputLayer)         (None, 227, 227, 3)       0         
    _________________________________________________________________
    block1_0_conv (Conv2D)       (None, 76, 76, 64)        9408      
    _________________________________________________________________
    block1_0_bn (BatchNormalizat (None, 76, 76, 64)        192       
    _________________________________________________________________
    block1_0_relu (Activation)   (None, 76, 76, 64)        0         
    _________________________________________________________________
    block1_0_dropout (Dropout)   (None, 76, 76, 64)        0         
    _________________________________________________________________
    block1_1_conv (Conv2D)       (None, 76, 76, 32)        18432     
    _________________________________________________________________
    block1_1_bn (BatchNormalizat (None, 76, 76, 32)        96        
    _________________________________________________________________
    block1_1_relu (Activation)   (None, 76, 76, 32)        0         
    _________________________________________________________________
    block1_1_dropout (Dropout)   (None, 76, 76, 32)        0         
    _________________________________________________________________
    block2_0_conv (Conv2D)       (None, 76, 76, 32)        9216      
    _________________________________________________________________
    block2_0_bn (BatchNormalizat (None, 76, 76, 32)        96        
    _________________________________________________________________
    block2_0_relu (Activation)   (None, 76, 76, 32)        0         
    _________________________________________________________________
    block2_0_dropout (Dropout)   (None, 76, 76, 32)        0         
    _________________________________________________________________
    block2_1_conv (Conv2D)       (None, 76, 76, 32)        9216      
    _________________________________________________________________
    block2_1_bn (BatchNormalizat (None, 76, 76, 32)        96        
    _________________________________________________________________
    max_pooling2d_1 (MaxPooling2 (None, 38, 38, 32)        0         
    _________________________________________________________________
    block2_1_relu (Activation)   (None, 38, 38, 32)        0         
    _________________________________________________________________
    block2_1_dropout (Dropout)   (None, 38, 38, 32)        0         
    _________________________________________________________________
    block2_2_conv (Conv2D)       (None, 38, 38, 32)        9216      
    _________________________________________________________________
    block2_2_bn (BatchNormalizat (None, 38, 38, 32)        96        
    _________________________________________________________________
    block2_2_relu (Activation)   (None, 38, 38, 32)        0         
    _________________________________________________________________
    block2_2_dropout (Dropout)   (None, 38, 38, 32)        0         
    _________________________________________________________________
    block3_0_conv (Conv2D)       (None, 38, 38, 32)        9216      
    _________________________________________________________________
    block3_0_bn (BatchNormalizat (None, 38, 38, 32)        96        
    _________________________________________________________________
    block3_0_relu (Activation)   (None, 38, 38, 32)        0         
    _________________________________________________________________
    block3_0_dropout (Dropout)   (None, 38, 38, 32)        0         
    _________________________________________________________________
    block4_0_conv (Conv2D)       (None, 38, 38, 64)        18432     
    _________________________________________________________________
    max_pooling2d_2 (MaxPooling2 (None, 19, 19, 64)        0         
    _________________________________________________________________
    block4_0_bn (BatchNormalizat (None, 19, 19, 64)        192       
    _________________________________________________________________
    block4_0_relu (Activation)   (None, 19, 19, 64)        0         
    _________________________________________________________________
    block4_0_dropout (Dropout)   (None, 19, 19, 64)        0         
    _________________________________________________________________
    block4_1_conv (Conv2D)       (None, 19, 19, 64)        36864     
    _________________________________________________________________
    block4_1_bn (BatchNormalizat (None, 19, 19, 64)        192       
    _________________________________________________________________
    block4_1_relu (Activation)   (None, 19, 19, 64)        0         
    _________________________________________________________________
    block4_1_dropout (Dropout)   (None, 19, 19, 64)        0         
    _________________________________________________________________
    block4_2_conv (Conv2D)       (None, 19, 19, 64)        36864     
    _________________________________________________________________
    block4_2_bn (BatchNormalizat (None, 19, 19, 64)        192       
    _________________________________________________________________
    max_pooling2d_3 (MaxPooling2 (None, 9, 9, 64)          0         
    _________________________________________________________________
    block4_2_relu (Activation)   (None, 9, 9, 64)          0         
    _________________________________________________________________
    block4_2_dropout (Dropout)   (None, 9, 9, 64)          0         
    _________________________________________________________________
    cccp4 (Conv2D)               (None, 9, 9, 256)         16640     
    _________________________________________________________________
    cccp5 (Conv2D)               (None, 9, 9, 64)          16448     
    _________________________________________________________________
    poolcp5 (MaxPooling2D)       (None, 4, 4, 64)          0         
    _________________________________________________________________
    cccp6 (Conv2D)               (None, 4, 4, 64)          36928     
    _________________________________________________________________
    poolcp6 (GlobalMaxPooling2D) (None, 64)                0         
    _________________________________________________________________
    dense_1 (Dense)              (None, 10)                650       
    _________________________________________________________________
    activation_1 (Activation)    (None, 10)                0         
    =================================================================
    Total params: 228,778
    Trainable params: 227,946
    Non-trainable params: 832
    _________________________________________________________________
    

    Conversion error:

    Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.contrib.lite.toco_convert(). Here is a list of operators for which you will need custom implementations: Max.
    

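    Cause: some Keras layers are wrappers around TensorFlow ops, and not all of them are supported when converting from TensorFlow to tflite: https://github.com/tensorflow/tensorflow/issues/20042
    Planned fix: re-implement the model directly with TensorFlow's own ops.
    Update: upgrading TensorFlow from 1.10 to 1.12 solved the problem, and the conversion succeeded.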

    pip install --upgrade tensorflow
    
    2. Xcode reports an error when loading the tflite model:
    Op builtin_code out or range: 82. Are you using old TFLite binary with newer model?
    Registration failed.
    

    Setting a breakpoint shows that the failure occurs at:

     tflite::InterpreterBuilder(*model, resolver)(&interpreter);
    

    In the end, changing the stride of the first convolution layer from 3 to 2 made it work. Possibly TFLite has no corresponding implementation for stride 3.
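To see what the stride change does to the network, the shapes and parameter counts in the layer summary can be reproduced by hand. This is a sketch under stated assumptions: Keras-style "same" padding, and a 7x7 bias-free first convolution, which is inferred from the parameter count 9408 rather than stated in the source. Both helper functions are hypothetical.

```python
import math


def conv_output_size(input_size, stride, kernel_size=None, padding="same"):
    """Spatial output size of a convolution along one dimension."""
    if padding == "same":
        return math.ceil(input_size / stride)
    if padding == "valid":
        return math.ceil((input_size - kernel_size + 1) / stride)
    raise ValueError("unknown padding: %r" % padding)


def conv_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Number of trainable parameters of a square 2D convolution."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


print(conv_output_size(227, stride=3))        # 76, as in block1_0_conv
print(conv_output_size(227, stride=2))        # 114 after the stride change
print(conv_params(7, 3, 64, use_bias=False))  # 9408, as in block1_0_conv
```

With stride 2 every later feature map grows accordingly, so the change also increases the compute cost of the downstream layers.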

    Reference

    1. How to deploy your own deep learning model on iOS (official TensorFlow example):
      https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/#0
    2. https://v-play.net/cross-platform-development/machine-learning-add-image-classification-for-ios-and-android-with-qt-and-tensorflow
    3. https://heartbeat.fritz.ai/neural-networks-on-mobile-devices-with-tensorflow-lite-a-tutorial-85b41f53230c
    4. How to quantize a model: https://www.tensorflow.org/lite/performance/post_training_quantization
    5. Accuracy mismatch between the TensorFlow model and the tflite model: https://github.com/tensorflow/tensorflow/issues/21921
