美文网首页iOS学习开发
TensorFlow Lite实战——在iOS上部署中文文本分类

TensorFlow Lite实战——在iOS上部署中文文本分类

作者: QYiZHong | 来源:发表于2019-07-07 15:26 被阅读45次

    前言

    本文所使用的分类模型来自于CNN-RNN中文文本分类,基于TensorFlow,感谢开源。

    最近一段时间需要用到中文文本分类这样一个功能,于是我马上想到了Create ML,但是经过自己的尝试以后发现Create ML并不支持中文的文本分类(不信可以自己试试)。

    最近发现有道词典有离线翻译这样一个功能,我猜测这应该就是把模型下载到本地使用了,这么一看模型部署到移动端理论上是可行的。但各个深度学习框架我只了解过tensorflow,于是在有这样一个需求之下,我又回到了tensorflow这个大坑,去年年底说我这辈子都不会再用tensorflow了,没想到真香了。

    实际上tensorflow所训练的模型是放在后端最合适,但由于我不想给APP维护一个健壮的后端,所以执着于把模型部署到移动端。这个是Demo

    言归正传,从头部署一个模型我可以归纳出几个步骤

    1. 训练并测试模型,将模型保存为ckpt格式
    2. 将ckpt模型固化转成pb模型
    3. 通过TensorFlow Lite提供的方法将pb模型转换为tflite模型
    4. 使用cocoapods的方式引入TensorFlow Lite,并把模型导入工程
    5. 封装调用模型逻辑,进行文本分类

    注意: 本篇博客仅根据上方的开源工程进行部署,其他的网络结构还需要具体问题具体分析。

    大致分类原理

    如果想要从头部署一遍,一定要对tensorflow有一定了解,因为不读懂工程的源码意思是基本上无法往下流程做的。

    这个工程把每一个文本中的字符映射成一个个数字(id),通过一系列玄学操作,得到一个一维数组,其中前10个就是我们要关注的值,因为标签只有10个。

    数据处理

    我们需要了解数据处理的方式即输入和输出,这样我们才能编写代码在iOS APP中进行预测。

    输入

    这个开源工程中会把每一个字符(汉字)映射成一个id,这个id来自于数据集中的行,意思就是第一行对应的字符id就是0,第二行对应的是1,以此类推。这样我们就获得了一个id的数组。并且这个id数组需要处理成一个固定长度,本文在iOS中处理方式为不足则数组后面添0,多余则移除数组末尾。

    输出

    输出的是一个数组,数量会超过10个,但因为数据集中的分类只有10个,所以我们只需要关注这个数组的前10个即可。这前10个数组对应的下标就是标签数组中的下标,数组的值就是预测的概率。所以输出的数组0-10的下标就对应了标签数组中0-10具体分类的可能性。

    部署

    训练模型

    本文使用开源工程中的CNN网络,因为TensorFlow Lite支持的operators有限,所以不是所有的TensorFlow中的operators都支持,如果出现不支持的情况就会在转换中出现类似如下的错误:

    Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.contrib.lite.toco_convert(). Here is a list of operators for which  you will need custom implementations: RandomUniform
    

    这里的错误中可以发现不支持的operator是RandomUniform。查找之后发现CNN中的tf.contrib.layers.dropout不受支持,但是这个问题不大,我们可以用L2正则化去替代它防止过拟合。

    下面是修改后的参考代码:

    # coding: utf-8
    from functools import partial
    
    import tensorflow as tf
    
    
    class TCNNConfig(object):
        """CNN配置参数"""
    
        embedding_dim = 64  # 词向量维度
        seq_length = 600  # 序列长度
        num_classes = 10  # 类别数
        num_filters = 256  # 卷积核数目
        kernel_size = 5  # 卷积核尺寸
        vocab_size = 5000  # 词汇表达小
    
        hidden_dim = 128  # 全连接层神经元
    
        dropout_keep_prob = 0.5  # dropout保留比例
        learning_rate = 1e-3  # 学习率
    
        batch_size = 64  # 每批训练大小
        num_epochs = 10  # 总迭代轮次
    
        print_per_batch = 100  # 每多少轮输出一次结果
        save_per_batch = 10  # 每多少轮存入tensorboard
    
        scale = 0.01
    
    
    class TextCNN(object):
        """文本分类,CNN模型"""
    
        def __init__(self, config):
            self.config = config
    
            # 三个待输入的数据
            self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
            self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
            self.cnn()
    
        def cnn(self):
            """CNN模型"""
            my_dense_layer = partial(
                tf.layers.dense, activation=tf.nn.relu,
                # 在这里传入了L2正则化函数,并在函数中传入正则化系数。
                kernel_regularizer=tf.contrib.layers.l2_regularizer(self.config.scale)
            )
            # 词向量映射
            with tf.device('/cpu:0'):
                embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
                embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)
    
            with tf.name_scope("cnn"):
                # CNN layer
                conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
                # global max pooling layer
                gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
    
            with tf.name_scope("score"):
                # 全连接层
                fc = my_dense_layer(gmp, self.config.hidden_dim, name='fc1')
                # fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
                # fc = tf.contrib.layers.dropout(fc, self.keep_prob)
                # fc = tf.nn.relu(fc)
    
                # 分类器
                self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
                self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别
    
            with tf.name_scope("optimize"):
                # 损失函数,交叉熵
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
                reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                self.loss = tf.add_n([tf.reduce_mean(cross_entropy)] + reg_losses)
                # self.loss = tf.reduce_mean(cross_entropy)
                # 优化器
                self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)
    
            with tf.name_scope("accuracy"):
                # 准确率
                correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
                self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    
    

    接下来在run_cnn.py中经过训练就能获得如下ckpt模型了:


    ckpt模型

    将ckpt模型固化转成pb模型

    在固化模型这一个环节,你需要通读这个开源工程才行,不然你肯定不了解它的网络结构以及它的输入和输出。这也是对iOS开发者非常不友好的地方。

    通过源码我们可以得知TextCNN这个类中的self.logits这个属性就是我们需要关注的输出,所以我们可以通过下面这段代码打印出tensor,然后找到我们需要的输出的name

    ops = sess.graph.get_operations()
            for op in ops:
                print(op)
    

    这里我们需要的name是

    output_node_names = "score/fc2/BiasAdd"
    

    参考源码:

    def freeze_graph(input_checkpoint):
        """
        :param input_checkpoint:
        :return:
        """
        # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
        # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
    
        # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
        output_node_names = "score/fc2/BiasAdd"
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    
        with tf.Session() as sess:
            saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
            output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
                sess=sess,
                input_graph_def=sess.graph_def,  # 等于:sess.graph_def
                output_node_names=output_node_names.split(",")
            )  # 如果有多个输出节点,以逗号隔开
    
            with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
                f.write(output_graph_def.SerializeToString())  # 序列化输出
    

    input_checkpoint为你的ckpt模型路径

    将pb模型转换为tflite模型

    下面是from_frozen_graph方法的注解。这里我就要吐槽一下了,TensorFlow Lite的文档未免太敷衍了,说好的传入参数是一个[tensor],结果老报错,在打断点调试了它们库的源码情况下发现竟然要求的是传入tensor的name???

    from_frozen_graph方法注解

    这个只要没有出现operator不支持的情况就很简单,直接上源码就完了

    def convert_to_tflite():
        input_tensors = [
            "input_x"
        ]
        output_tensors = [
            "score/fc2/BiasAdd"
        ]
        converter = tf.lite.TFLiteConverter.from_frozen_graph(
            output_graph,
            input_tensors,
            output_tensors)
        converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                tf.lite.OpsSet.SELECT_TF_OPS]
        tflite_model = converter.convert()
        open(output_tflite_model, "wb").write(tflite_model)
    

    其中input_x是输入的name

    使用cocoapods的方式引入TensorFlow Lite

    TensorFlow Lite有好几个库,原生的需要写C++,在一顿操作之下我放弃了,完全看不懂tensor的输入嘛。还有OC封装的以及swift封装的。因为我的工程是swift写的,所以我直接用swift的TensorFlow Lite库

    按照他们的README

    pod 'TensorFlowLiteSwift'
    
    import TensorFlowLite
    

    就引入了,这一点就很友好了,比什么直接编译TensorFlow到iOS工程里那是简单的不能再简单了。

    封装调用模型逻辑,进行文本分类

    在喂数据进行预测时我们也要按照开源工程里喂数据的方式进行一番操作。调用的逻辑我们可以参考官方Example

    导入模型

    我们需要导入模型、分类和字符id,这在本文的前言中提供的demo中有体现。

    必须导入的东西

    初始化Interpreter

    private init() {
            let options = InterpreterOptions()
            do {
                // Create the `Interpreter`.
                let modelPath = Bundle.init(for: TextClassifier.self).path(forResource: "model", ofType: "tflite")!
                interpreter = try Interpreter(modelPath: modelPath, options: options)
                // Allocate memory for the model's input `Tensor`s.
                try interpreter.allocateTensors()
            } catch {
                print("Failed to create the interpreter with error: \(error.localizedDescription)")
            }
        }
    

    加载标签、id以及将字符转换为id

    private func loadLabels() {
            if let path = Bundle.init(for: TextClassifier.self).path(forResource: "labels", ofType: "txt") {
                let fileManager = FileManager.default
                let txtData = fileManager.contents(atPath: path)!
                let content = String.init(data: txtData, encoding: .utf8)
                let rowArray = content?.split(separator: "\n") ?? []
                for row in rowArray {
                    labels.append(String(row))
                }
            }
        }
        
        private func loadTextId() {
            if let path = Bundle.init(for: TextClassifier.self).path(forResource: "text_id", ofType: "txt") {
                let fileManager = FileManager.default
                let txtData = fileManager.contents(atPath: path)!
                let content = String.init(data: txtData, encoding: .utf8)
                let rowArray = content?.split(separator: "\n") ?? []
                var i = 0
                for row in rowArray {
                    textIdInfo[String(row)] = i
                    i += 1
                }
            }
        }
        
        private func transformTextToId(_ text: String) -> [Int] {
            var idArray: [Int] = []
            for str in text {
                idArray.append(textIdInfo[String(str)]!)
            }
            //根据python工程中的输入设置,超出截取前面,不足后面补0
            while idArray.count < 2400 {
                idArray.append(0)
            }
            while idArray.count > 2400 {
                idArray.removeLast()
            }
            return idArray
        }
    

    进行预测

    public func runModel(with text: String, closure: @escaping(InferenceReslutClosure)) {
            DispatchQueue.global().async {
                let idArray = self.transformTextToId(text)
                let outputTensor: Tensor
                do {
                    _ = try self.interpreter.input(at: 0)
                    let idData = Data.init(bytes: idArray, count: idArray.count)
                    try self.interpreter.copy(idData, toInputAt: 0)
                    try self.interpreter.invoke()
                    outputTensor = try self.interpreter.output(at: 0)
                } catch {
                    print("An error occurred while entering data: \(error.localizedDescription)")
                    return
                }
                let results: [Float]
                switch outputTensor.dataType {
                case .uInt8:
                    guard let quantization = outputTensor.quantizationParameters else {
                        print("No results returned because the quantization values for the output tensor are nil.")
                        return
                    }
                    let quantizedResults = [UInt8](outputTensor.data)
                    results = quantizedResults.map {
                        quantization.scale * Float(Int($0) - quantization.zeroPoint)
                    }
                case .float32:
                    results = outputTensor.data.withUnsafeBytes( { (ptr: UnsafeRawBufferPointer) in
                        [Float32](UnsafeBufferPointer.init(start: ptr.baseAddress?.assumingMemoryBound(to: Float32.self), count: ptr.count))
                    })
                default:
                    print("Output tensor data type \(outputTensor.dataType) is unsupported for this app.")
                    return
                }
                let resultArray = self.getTopN(results: results)
                DispatchQueue.main.async {
                    closure(resultArray)
                }
            }
        }
    

    首先我们需要把[Int]类型转换为Data类型提供给interpreter,可以如下方法转换

    let idData = Data.init(bytes: idArray, count: idArray.count)
    

    invoke()方法为调用模型进行预测

    我们拿到输出outputTensor以后,它的dataType中的float32类型就是我们需要的输出,这是因为在开源工程中的输出就是float32类型。这里我们需要用swift的指针去把Data类型换为[Float]类型,如下:

    results = outputTensor.data.withUnsafeBytes( { (ptr: UnsafeRawBufferPointer) in
                        [Float32](UnsafeBufferPointer.init(start: ptr.baseAddress?.assumingMemoryBound(to: Float32.self), count: ptr.count))
                    })
    

    至于上面那个.UInt8我没有搞懂是什么意思,但我想我的输出都是float32类型,所以应该是不会走上面那个case。

    最后我们通过getTopN方法取到前3个可能性最大的标签(预测值)

    private func getTopN(results: [Float]) -> [Inference] {
            //创建元组数组 [(labelIndex: Int, confidence: Float)]
            let zippedResults = zip(labels.indices, results)
            //从大到小排序并选出前resultCount个(根据python工程中的训练,只取前10个,因为分类只有10个)
            let sortedResults = zippedResults.sorted { $0.1 > $1.1 }.prefix(resultCount)
            //返回前resultCount对应的标签以及预测值
            return sortedResults.map { result in Inference.init(confidence: result.1, label: labels[result.0]) }
        }
    

    这里取的逻辑就像上述所说的,我们只关注输出一维数组的前10个元素,然后给他们排个序取最大三个值,这三个值所在的下标直接在标签数组中取值就能获得对应的预测分类

    最后

    博客只是一个预览,详细的逻辑还是需要直接看Demo

    参考

    CNN-RNN中文文本分类,基于TensorFlow
    TensorFlow for Poets 2: TFLite iOS
    【IOS/Android】TensorflowLite移动端部署
    TensorFlow Lite Swift Example
    Tensorflow Convert pb file to TFLITE using python

    相关文章

      网友评论

        本文标题:TensorFlow Lite实战——在iOS上部署中文文本分类

        本文链接:https://www.haomeiwen.com/subject/akmbhctx.html