美文网首页MXNET神经网络与深度学习
端到端的OCR:LSTM+CTC的实现

端到端的OCR:LSTM+CTC的实现

作者: xlvector | 来源:发表于2016-06-13 10:54 被阅读21993次

前面提到了用CNN来做OCR。这篇文章介绍另一种做OCR的方法,就是通过LSTM+CTC。这种方法的好处是他可以事先不用知道一共有几个字符需要识别。之前我试过不用CTC,只用LSTM,效果一直不行,后来下决心加上CTC,效果一下就上去了。

CTC是序列标志的一个重要算法,它主要解决了label对齐的问题。有很多实现。百度IDL在16年初公开了一个GPU的实现,号称速度比之前的theano-ctc, stanford-ctc都要快。Mxnet目前还没有ctc的实现,因此决定吧warpctc集成进mxnet。

根据issue里作者们的建议,决定和集成torch一样,写一个plugin,因此C++代码放在plugin/warpctc目录中。整个集成任务其实就是写一个wrapctc的op。代码在 plugin/warpctc/warpctc-inl.h.

CTC这一层其实和SoftmaxOutput很像。其实他们的forward的实现就是一模一样的。唯一的差别就是backward中grad的实现,在这里需要调用warpctc的compute_ctc_loss函数来计算梯度。实际上warpctc的主要接口也就是这个函数。

下面说说具体怎么用lstm+ctc来做ocr的任务。详细的代码在 examples/warpctc/lstm_ocr.py。这里只说说大体思路。

假设我们要解决的是4位数字的识别,图片是80*30的图片。那么我们就将每张图片按列切分成80个30维的向量。然后作为一个lstm的80个输入。一个lstm的输出和输入数目应该是相同的。而我们的预测目标却只有4个数字。而不是80个数字。在没有用ctc时我想了两个解决方案。第一个是用encode-decode模式。也就是80个输入做encode,然后decode成4个输出。实测效果很挫。第二个是把4个label每个copy20遍,从而变成80个label。实测也很挫。没办法,最后只能用ctc loss了。

用ctc loss的体会就是,如果input的长度远远大于label的长度,比如我这里是80和4的关系。那么一开始的收敛会比较慢。在其中有一段时间cost几乎不变。此刻一定要有耐心,最终一定会收敛的。在ocr识别的这个例子上最终可以收敛到95%的精度。

目前代码还在等待merge。pull request

---------------

欢迎关注 微信公众号【ResysChina】

相关文章

网友评论

  • c3a824d9e8e1:你好,mxnet中的mxnet_predict.py进行预测的时候,好像不支持GPU的选项,只能在CPU下跑
  • 门牙_766b:TensorFlow的ctc的decode直接给了消除空格和重复标签的最佳路径,我想得到完整的包含空格和重复标签的路径以及各个标签的概率信息,具体代码应该怎么实现
  • 020bf31ab9e8:作者您好,我最近刚刚开始学lstm的小白,在git上找了一个源码,想做一下手写文字的识别。不过第一步就有了疑问,看论文和别人的总结也没明白。

    我输入的原始数据是手写单个文字的坐标值序列,这时候要训练网络的话,标签值应该是什么呢?像文章里您说要预测4个数字,那label就是这四个数值吗?那我想预测文字怎么办呢

    求大神不吝赐教~
  • 乱乱_c7c1:你好,请问你训练模型时,刚开始进行验证的时候,解码出来的是空矩阵吗
  • d3b3eed271bb:您好,楼主,因为我电脑只有windows,而貌似百度的这个ctc是不支持的,那是不是用别的可在windows下用的ctc来取代呢?
  • 6ff75969fda3:你好我在装好后直接运行 example/warpctc的toy_ctc.py这个例子会报:
    [21:46:11] /home/chiron/object/mxnet/dmlc-core/include/dmlc/./logging.h:300: [21:46:11] src/operator/./slice_channel-inl.h:178: Check failed: ishape[real_axis] == param_.num_outputs (3000 vs. 30) If squeeze axis is True, the size of the sliced axis must be the same as num_outputs. Input shape=(100,3000), axis=1, num_outputs=30.

    Stack trace returned 10 entries:
    [bt] (0) /home/chiron/object/mxnet/python/mxnet/../../lib/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x3c) [0x7f3f3bd74eec]
    [bt] (1)

    infer_shape error. Arguments:
    label: (100, 6)
    l0_init_c: (100, 128)
    l1_init_h: (100, 128)
    l0_init_h: (100, 128)
    data: (100, 3000)
    l1_init_c: (100, 128)
    Traceback (most recent call last):
    File "lstm_ocr_myiter.py", line 197, in <module>
    batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 100))
    File "../../python/mxnet/model.py", line 772, in fit
    self._init_params(dict(data.provide_data+data.provide_label))
    File "../../python/mxnet/model.py", line 500, in _init_params
    arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
    File "../../python/mxnet/symbol.py", line 588, in infer_shape
    res = self._infer_shape_impl(False, *args, **kwargs)
    File "../../python/mxnet/symbol.py", line 671, in _infer_shape_impl
    ctypes.byref(complete)))
    File "../../python/mxnet/base.py", line 78, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
    mxnet.base.MXNetError: Error in operator slicechannel0: [21:46:11] src/operator/./slice_channel-inl.h:178: Check failed: ishape[real_axis] == param_.num_outputs (3000 vs. 30) If squeeze axis is True, the size of the sliced axis must be the same as num_outputs. Input shape=(100,3000), axis=1, num_outputs=30.

    这种错误,这是为什么
    5d99b933ccbc:请问您问题解决了吗?我也遇到了同样的问题
  • fad4b117d3b4:zhouzhirui@gpu-dl-01:~/mxnet/example/warpctc$ python lstm_ocr.py
    lstm_ocr.py:196: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
    begin fit
    Traceback (most recent call last):
    File "lstm_ocr.py", line 209, in <module>
    epoch_end_callback = mx.callback.do_checkpoint(prefix, 1))
    NameError: name 'prefix' is not defined
    乱乱_c7c1:你好,请问一下,在刚开始训练的时候,解码出来的是个空向量吗
    fad4b117d3b4:PS:这两行已经去掉了注释 然后重新编译了MXNET
    WARPCTC_PATH = $(HOME)/warp-ctc
    MXNET_PLUGINS += plugin/warpctc/warpctc.mk
    fad4b117d3b4:我将prefix改为None,再运行
    zhouzhirui@gpu-dl-01:~/mxnet/example/warpctc$ sudo python lstm_ocr.py
    lstm_ocr.py:196: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
    begin fit
    ../../python/mxnet/model.py:516: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
    self.initializer(k, v)
    2017-02-19 22:28:07,437 Start training with [gpu(0)]
    iter
    terminate called after throwing an instance of 'std::runtime_error'
    what(): Error: compute_ctc_loss, stat = execution failed
    已放弃 (核心已转储)
  • fad4b117d3b4:请问,下面这个错误怎么解决呢?
    python toy_ctc.py

    toy_ctc.py:152: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
    begin fit
    ../../python/mxnet/model.py:516: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
    self.initializer(k, v)
    2017-02-19 22:19:10,732 Start training with [gpu(0)]
    terminate called after throwing an instance of 'std::runtime_error'
    what(): Error: compute_ctc_loss, stat = execution failed
    已放弃 (核心已转储)
    fad4b117d3b4:sudo python toy_ctc.py

    toy_ctc.py:152: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
    begin fit
    ../../python/mxnet/model.py:516: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
    self.initializer(k, v)
    2017-02-19 22:22:27,257 Start training with [gpu(0)]
    terminate called after throwing an instance of 'std::runtime_error'
    what(): Error: compute_ctc_loss, stat = execution failed
    已放弃 (核心已转储)
    同样
  • AlexanderYau:你好博主,我最近在做 Street View House Numbers(http://ufldl.stanford.edu/housenumbers/)的识别,就是从街景图片中直接把数字定位并且识别出来,请问你的方法可以做吗?
  • e068ef7de442:hello, 我用你在mxnet的example里面的代码,跑你的例子程序是可以看到效果的,但是我增加了label的种类,希望这个网络还可以去识别字母,但是这样的话,输出的training accuracy一直是0,不知道是怎么回事
    e068ef7de442:@冲浪的考拉 我知道问题了...我忘了改lstm最后全连接里面label的种类个数.....忽略我这个逗比的问题吧............
  • e596d7bcf9c9:请问,我用自己的数据,数字加上一个字母识别率怎么一直是0呢,用0-9的数字没有问题,一旦加上一个字母识别率就是0了,这是怎么回事呢,先谢谢了。
  • andycpf:你好。。。我用mxnet里面的 warp-ctc中的 lstm-ocr训练自己的数据(箱号数据,4位字母和六位的数字),为什么识别率一直为0呢?
    e596d7bcf9c9:@andycpf 请问你的问题解决了吗,我也碰到了同样的问题,不知道怎么往下弄了
  • cb557600586d:你好,我用mxnet 训练好模型,加载模型然后输入一张图片测试 总是报错 input node is not complete ,是为啥
    e596d7bcf9c9:@沧浪_之水,你用mxnet训练模型,用了多长时间啊,多少个epoch?我已经训练70个epoch, 准确率才只有0.01。多谢了
    cb557600586d:@xlvector 好的,我再折腾下
    xlvector:@沧浪_之水 我好像只写了训练的代码,没有写测试的
  • 7720879eab28:您好,再次请教一个问题,ctc的识别结果的置信度怎么计算更合理呢?直接相乘感觉不够合理。
  • 4f251c3aca42:首次使用mxnet和warpctc,按照readme已经成功编译了mxnet和warpctc,编译mxnet的时候也在config中去掉了warpctc的注释。但是在运行toy_ctc.py的时候,出现以下错误:
    Traceback (most recent call last):
    File "toy_ctc.py", line 144, in <module>
    symbol = sym_gen(SEQ_LENGTH)
    File "toy_ctc.py", line 135, in sym_gen
    num_label = num_label)
    File "/home/deeplearning/mxnet/example/warpctc/lstm.py", line 78, in lstm_unroll
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
    AttributeError: 'module' object has no attribute 'WarpCTC'

    看上去可能是因为在编译的时候没能够把warpctc编译到mxnet的原因,不知道是不是遗漏了什么步骤,还望指点。
    谢谢。
    xlvector:@EricMoon 👍
    4f251c3aca42:@xlvector
    非常感谢您的热心,问题已经解决了,是我自己粗心导致的。
    xlvector:@EricMoon 你是否可以把你现在的代码push到git的一个branch。我帮你看看
  • 7720879eab28:再请教一个问题,对于不定长的输入序列,如果用bucketing这种方式,每个bucket中padding的数据会对ctc的效果有影响吗?或者您打算用何种方式处理不定长的输入序列呢?非常感谢
    7720879eab28:@xlvector 您好,你说的“调用warpctc的时候截取空格前的”,这个怎么实现呢?mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
    这个函数中有参数来实现这个功能吗?还是说在accuracy这个函数里面实现呢?
    xlvector:@arestorres 我正在实现,这周末可以搞定
    xlvector:@arestorres 简单的可以这么做。label_length 作为最长的label的长度。然后如果长度短的,就在后面补空格。然后在调用warpctc的时候截取空格前的。
  • 7720879eab28:请教一个问题,对于label不定长的问题,输入数据需要预先归一化到统一尺寸吗?
    7720879eab28:@xlvector 都要在哪里修改呢?目前对ctc的理解还不是很深刻...
    xlvector:@arestorres 如果没人改,等我有空了我会自己改。
    xlvector:@arestorres 目前我的实现还没有支持不定长的。我正在等待有人基于我这个上面修改。这个修改应该不难。
  • 19beab2c37c6:不好意思,接上面的问题,考虑一个应用场景,手写中文的识别,分割是个老大难的问题,单子识别我可以用CNN做到很高的精度,分割出一行文字也不是特别困难。一行文字的个数不确定,这个问题能用CNN+ctc来解决吗?
    xlvector:@shuokay 是的。好多论文都是这么解决的。ctc在识别问题上主要就是解决这类问题
  • 19beab2c37c6:请问,就像文章中说的,ctc可以解决事先不知道label长度的情况。文章中/代码中,label的长度是4,如果,label长度不确定,可以直接复用您的代码吗?或者是否有其他解决方法?
    xlvector:@shuokay 是的。我也是玩票性质。本身不是研究ocr的。所以只是抛了个砖。等着引玉。:smile:
    19beab2c37c6:OK,我研究一下,我主要玩CNN,对rnn不是特别熟悉,所以可能得花点时间。另外,问一下,您是不是搞推荐的那个项亮?
    xlvector:@shuokay 需要改。但是不难改。你可以试试写一个不定长的例子。然后提交一个pr。
  • 1031ecea4039:恒定误差传送带啊
    xlvector:@疯狂的程序员 干ctc一定要睡觉。因为结果只会在睡醒后才会出来。坐着干等会疯的。

本文标题:端到端的OCR:LSTM+CTC的实现

本文链接:https://www.haomeiwen.com/subject/vixydttx.html