I previously wrote about doing OCR with a CNN. This post covers another way to do OCR: LSTM + CTC. The advantage of this approach is that you don't need to know in advance how many characters there are to recognize. I had tried using an LSTM alone, without CTC, and the results were consistently poor; once I finally added CTC, the accuracy jumped immediately.
CTC is an important algorithm for sequence labeling; its main contribution is solving the label alignment problem. There are many implementations. In early 2016 Baidu IDL open-sourced a GPU implementation, warp-ctc, claiming it is faster than the earlier theano-ctc and stanford-ctc. MXNet does not yet have a CTC implementation, so I decided to integrate warpctc into MXNet.
Following the maintainers' suggestion in the issue, I took the same route as the Torch integration and wrote it as a plugin, so the C++ code lives in the plugin/warpctc directory. The whole integration boils down to writing a warpctc op; the code is in plugin/warpctc/warpctc-inl.h.
The CTC layer is actually very similar to SoftmaxOutput: their forward passes are implemented identically. The only difference is how the gradient is computed in backward, where we call warpctc's compute_ctc_loss function to get the gradient; in fact, compute_ctc_loss is essentially warpctc's main interface.
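From the user's point of view the two heads are attached in almost the same way. A minimal sketch of the comparison, assuming per-step predictions named pred (the variable names and shape comments are placeholders, not part of either op's API beyond the WarpCTC call used in the example):

import mxnet as mx

# Per-step class scores coming out of the unrolled LSTM.
pred = mx.sym.Variable('pred')     # assumed shape: (seq_len * batch, num_classes)
label = mx.sym.Variable('label')   # assumed shape: (batch, num_label)
num_label, seq_len = 4, 80

# An ordinary softmax head ...
softmax = mx.sym.SoftmaxOutput(data=pred, label=label)

# ... and the CTC head, attached the same way. The forward pass is the same
# softmax; only the backward pass differs, calling warpctc's compute_ctc_loss
# to obtain the gradient of the CTC loss.
ctc = mx.sym.WarpCTC(data=pred, label=label,
                     label_length=num_label, input_length=seq_len)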
Now let's look at how to actually use LSTM + CTC for the OCR task. The full code is in examples/warpctc/lstm_ocr.py; here I only outline the general idea.
Suppose the task is recognizing 4-digit numbers in 80x30 images. We slice each image by column into 80 vectors of 30 dimensions and feed them as the 80 inputs of an LSTM. An LSTM produces as many outputs as it has inputs, but our target is only 4 digits, not 80. Before using CTC I tried two workarounds. The first was an encoder-decoder setup: encode the 80 inputs, then decode them into 4 outputs. It performed badly in practice. The second was to copy each of the 4 labels 20 times to get 80 labels. That was just as bad. In the end the only option left was the CTC loss.
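A rough sketch of that construction, assuming 80x30 images flattened to 2400-dimensional vectors, 4-digit labels, and 11 output classes (10 digits plus the CTC blank). The real code in examples/warpctc/lstm_ocr.py runs an unrolled LSTM (from lstm.py) over the 80 columns; here a single shared fully connected layer stands in for it just to keep the sketch short:

import mxnet as mx

SEQ_LENGTH = 80    # one time step per image column
NUM_HIDDEN = 128   # assumed hidden size
NUM_LABEL = 4      # four digits per image
NUM_CLASSES = 11   # 10 digits + 1 blank (warp-ctc reserves label 0 for blank)

data = mx.sym.Variable('data')    # flattened 80x30 image: shape (batch, 2400)
label = mx.sym.Variable('label')  # shape (batch, 4)

# Slice the flattened image into 80 column vectors of 30 dimensions each.
columns = mx.sym.SliceChannel(data=data, num_outputs=SEQ_LENGTH,
                              axis=1, squeeze_axis=True)

# Stand-in for the unrolled LSTM: the same fully connected layer applied to
# every column (the real example stacks the LSTM hidden states instead).
w = mx.sym.Variable('step_weight')
b = mx.sym.Variable('step_bias')
hiddens = [mx.sym.FullyConnected(data=columns[t], weight=w, bias=b,
                                 num_hidden=NUM_HIDDEN)
           for t in range(SEQ_LENGTH)]
hidden_concat = mx.sym.Concat(*hiddens, dim=0)   # (80 * batch, NUM_HIDDEN)

# Per-step class scores, then the CTC loss head.
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=NUM_CLASSES)
sm = mx.sym.WarpCTC(data=pred, label=label,
                    label_length=NUM_LABEL, input_length=SEQ_LENGTH)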
One thing I learned about the CTC loss: when the input length is much larger than the label length (here, 80 vs. 4), convergence is slow at first, and there is a stretch where the cost barely moves. Be patient at that point; it does converge eventually. On this OCR example it ultimately reaches about 95% accuracy.
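To measure that accuracy, the 80 per-step predictions have to be collapsed back into a short digit sequence. A minimal sketch of greedy (best-path) CTC decoding, assuming class 0 is the blank as in warp-ctc's convention:

import numpy as np

def ctc_greedy_decode(probs, blank=0):
    """probs: per-step class scores with shape (seq_len, num_classes)."""
    best_path = np.argmax(probs, axis=1)   # most likely class at each step
    decoded, prev = [], None
    for p in best_path:
        # collapse consecutive repeats, then drop blanks
        if p != prev and p != blank:
            decoded.append(int(p))
        prev = p
    return decoded

# Example: the length-8 path [blank, blank, 3, blank, 3, 9, 9, blank]
# decodes to [3, 3, 9].
path = [0, 0, 3, 0, 3, 9, 9, 0]
probs = np.eye(11)[path]          # one-hot scores just for illustration
print(ctc_greedy_decode(probs))   # -> [3, 3, 9]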
The code is currently waiting to be merged; see the pull request.
---------------
You are welcome to follow the WeChat official account [ResysChina].
Reader comments
My raw input is a sequence of coordinate values for single handwritten characters. In that case, what should the labels be when training the network? In the article you predict 4 digits, so is the label just those 4 values? What should I do if I want to predict characters instead?
Any advice would be greatly appreciated~
[21:46:11] /home/chiron/object/mxnet/dmlc-core/include/dmlc/./logging.h:300: [21:46:11] src/operator/./slice_channel-inl.h:178: Check failed: ishape[real_axis] == param_.num_outputs (3000 vs. 30) If squeeze axis is True, the size of the sliced axis must be the same as num_outputs. Input shape=(100,3000), axis=1, num_outputs=30.
Stack trace returned 10 entries:
[bt] (0) /home/chiron/object/mxnet/python/mxnet/../../lib/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x3c) [0x7f3f3bd74eec]
[bt] (1)
infer_shape error. Arguments:
label: (100, 6)
l0_init_c: (100, 128)
l1_init_h: (100, 128)
l0_init_h: (100, 128)
data: (100, 3000)
l1_init_c: (100, 128)
Traceback (most recent call last):
File "lstm_ocr_myiter.py", line 197, in <module>
batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 100))
File "../../python/mxnet/model.py", line 772, in fit
self._init_params(dict(data.provide_data+data.provide_label))
File "../../python/mxnet/model.py", line 500, in _init_params
arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
File "../../python/mxnet/symbol.py", line 588, in infer_shape
res = self._infer_shape_impl(False, *args, **kwargs)
File "../../python/mxnet/symbol.py", line 671, in _infer_shape_impl
ctypes.byref(complete)))
File "../../python/mxnet/base.py", line 78, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: Error in operator slicechannel0: [21:46:11] src/operator/./slice_channel-inl.h:178: Check failed: ishape[real_axis] == param_.num_outputs (3000 vs. 30) If squeeze axis is True, the size of the sliced axis must be the same as num_outputs. Input shape=(100,3000), axis=1, num_outputs=30.
Why does this error occur?
lstm_ocr.py:196: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
begin fit
Traceback (most recent call last):
File "lstm_ocr.py", line 209, in <module>
epoch_end_callback = mx.callback.do_checkpoint(prefix, 1))
NameError: name 'prefix' is not defined
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
zhouzhirui@gpu-dl-01:~/mxnet/example/warpctc$ sudo python lstm_ocr.py
lstm_ocr.py:196: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
begin fit
../../python/mxnet/model.py:516: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
self.initializer(k, v)
2017-02-19 22:28:07,437 Start training with [gpu(0)]
iter
terminate called after throwing an instance of 'std::runtime_error'
what(): Error: compute_ctc_loss, stat = execution failed
Aborted (core dumped)
python toy_ctc.py
toy_ctc.py:152: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
begin fit
../../python/mxnet/model.py:516: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
self.initializer(k, v)
2017-02-19 22:19:10,732 Start training with [gpu(0)]
terminate called after throwing an instance of 'std::runtime_error'
what(): Error: compute_ctc_loss, stat = execution failed
Aborted (core dumped)
toy_ctc.py:152: DeprecationWarning: mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
begin fit
../../python/mxnet/model.py:516: DeprecationWarning: Calling initializer with init(str, NDArray) has been deprecated.please use init(mx.init.InitDesc(...), NDArray) instead.
self.initializer(k, v)
2017-02-19 22:22:27,257 Start training with [gpu(0)]
terminate called after throwing an instance of 'std::runtime_error'
what(): Error: compute_ctc_loss, stat = execution failed
Aborted (core dumped)
Same issue here.
Traceback (most recent call last):
File "toy_ctc.py", line 144, in <module>
symbol = sym_gen(SEQ_LENGTH)
File "toy_ctc.py", line 135, in sym_gen
num_label = num_label)
File "/home/deeplearning/mxnet/example/warpctc/lstm.py", line 78, in lstm_unroll
sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
AttributeError: 'module' object has no attribute 'WarpCTC'
It looks like warpctc was not compiled into mxnet during the build. I'm not sure whether I missed a step; any pointers would be appreciated.
Thanks.
Thank you very much for your help. The problem is solved now; it was caused by my own carelessness.
Does this function have a parameter for that, or should it be implemented inside the accuracy function?