标签:文本分类
作者:炼己者
本博客所有内容以学习、研究和分享为主,如需转载,请联系本人,标明作者和出处,并且是非商业用途,谢谢!
半监督学习文本分类系列
用半监督算法做文本分类(sklearn)
sklearn半监督学习
基于自训练的半监督文本分类算法
基于协同训练的半监督文本分类算法
好久没有更新博客了,最近一直在忙着学c++,把python版本的项目转为c++,对于小白的我真的好难啊。每天下班后先学c++,尽管学了这么久,但是距离把python版本的项目转为c++还是有段距离,再接再厉吧!
一. 摘要
之前拿到了80万条数据,要做广告识别,简单的说就是拿到一段文本,让模型判断它是不是广告。这就是一个二分类的问题。但是80万条数据是没有标签的啊,我总不能把它们标注完然后去训练模型吧。所以研究了很多半监督学习方法,试图用半监督学习来解决这个问题。
半监督学习算法效果并不好,精度太低了。最后想到了一个好方法,先标注2万条数据,然后用textcnn训练这份标注好的数据,预测出文本的标签。最后我根据这些标签来排查。简单来说就是先让机器替我标注数据,然后我再进行复核。虽然有错误,不过大大地加快了我标注数据的速度。最后我一共复核了8万条数据,然后训练模型,最终效果还可以,准确率是0.95
二.操作
简单说下怎么使用textcnn,这份代码来自于github上的一个大佬,结构很简单,很好理解的一份代码,基本上只要会一些tensorflow的基本语法,再稍微理解一下cnn网络的原理,就可以理解这份代码了。后期我会尝试着把自己对tensorflow和cnn的理解分享给大家。关键还是怎么把这份代码用在自己的数据集上,这样才能让自己放心。没用起来,总会觉得心里有点虚。
代码源于这里cnn和rnn,它里面介绍了cnn和rnn两种文本分类方法。用法是一样的。不过里面文件有点多,可能看着会有点懵逼。所以我简化了一下,大家可以直接参考这里。
点开这个链接(https://github.com/lianjizhe/adver-project),然后clone一下,打开文件夹,发现有这些文件
先说一下各个文件夹是什么意思,checkpoints文件夹里放着是你训练好的模型(当你运行run_cnn.py之后会产生的文件),data文件夹里放着需要处理的数据,cnn_model.py里是构建cnn模型的py文件,
run_cnn.py这里就是我们要运行的文件。
将代码用到自己数据集的流程:
1.打开data文件夹,再打开里面的cnews文件夹,你会看到里面有四个文件
cnews.test.txt是测试数据集,cnews.train.txt是训练数据集,cnews.val.txt是验证数据集,cnews.vocab.txt是词汇表。你所需要替换的只有前三个文件。按里面的数据格式替换即可。一共是两列,一列是标签,另一列就是文本数据。文件名也不要改动。
由于github貌似传不了大文件,我就在这里给大家展示里面的数据格式,三个文件的数据格式是一样的。
体育 马晓旭意外受伤让国奥警惕
体育 商瑞华首战复仇心切
这里提示一下,训练数据集最好分布均衡,比如我的广告数据集,就是做二分类。那么假设训练集一共10000条数据,我就要找到5000条广告数据和5000条非广告数据来做训练集。
- 数据替换完毕之后,要更改代码了。代码只需要改动两个地方就可以了。打开cnn_model.py文件,找到下面代码
改动里面的num_classes就可以了,做几分类就改成多少。
然后打开data文件夹里面的cnews_loader.py文件,找到下面代码
改掉categories里的标签,把它们改成你自己设置的标签就好了
- 到这里就结束了。然后cmd,切换到run_cnn.py文件所在目录,敲入
python run_cnn.py train
这样就是训练模型了,训练好之后,模型会自动保存到checkpoints文件夹里面。
测试模型的效果输入以下命令即可
python run_cnn.py test
然后你就会得到模型的效果了。
三.总结
- 很多任务其实不用你从头到尾把代码写一遍,只要能利用别人写好的代码完成你的任务就好,你要去理解别人的代码,然后做到随意改动,得到你想要的东西,这很关键,学别人的代码会让你成长的更快。
- 我们大部分人做项目都是尽力去保证精度,很少会去考虑时间性能。你训练好的模型测试数据时一秒钟可以跑几条?这在工作的时候很重要,因为效率至上。所以在优化的时候要好好注意!
通过这个项目,给大家几个建议:
- 当你拿到一批没有标注的数据时,第一步就是先标注3万条数据,跑跑模型,测测效果。然后开始利用训练好的模型去预测未标注数据的标签。然后通过复核标签来增加你的训练集,这样做就可以提升你标注数据的速度。
- 训练集要保证数据均衡,比如广告数据和非广告数据的数量要相等。大部分数据肯定是非广告数据,你在复核数据的时候可以直接取出预测标签为广告的数据进行复核。
- 时间性能的提升我这里做了两步,首先是降低了每个句子的维度,一开始设置的是600维,后来改成了300维。其次是对数据进行了分析,把所有的数据按字数做了分划,最后发现15个字以内的句子都是非广告数据,而且15个字以内的文本占全部文本的80%,这就说明我的数据有80%是非广告数据啊。这样在做测试的时候先对每个句子的长度进行判断,如果大于15个字再用模型预测,如果小于15个字直接判定为非广告。
- 通过上面两步的操作,时间性能得到了提升,最终一秒钟可以跑7000条数据
如果大家还有问题可以在博客下面提问,我们一起探讨!
以下是我所有文章的目录,大家如果感兴趣,也可以前往查看
👉戳右边:打开它,也许会看到很多对你有帮助的文章
网友评论