美文网首页
PyTorch-YOLOv3训练自定义数据集

PyTorch-YOLOv3训练自定义数据集

作者: 湯木 | 来源:发表于2020-05-17 16:03 被阅读0次

  先给出参考的GitHub链接:https://github.com/eriklindernoren/PyTorch-YOLOv3
  其实在readme中已经有详细的训练自定义数据集教程,我记录一下自己在配置过程中遇到的问题,以供自己复习

配置自定义数据集详细过程

  一开始我是计划先下载coco数据集,跑一下yolov3,但是考虑到数据集本身比较大,而且跟我的工作没有什么直接关系,如果是想跑通程序的话,还不如直接跑自己的数据集,遇到什么问题就直接解决。
  我当前的数据集是行驶证和身份证,对于数据集的生成方法我会在后续补充,现在得到的就是可以用来训练的格式,都放在data/custom目录中,结构如下:

自定义数据集结构
  其中images目录下存放的都是图片,这就不多说了,labels目录中存放的是每张图片对应的标签,需要注意的是:标签的文件名和图片的文件名一一对应,例如:first_sheet_01_1.txt对应first_sheet_01_1.jpglabels的目录结构如下:
labels的目录结构
  那么 first_sheet_01_1.txt 里面是什么呢?每一行代表一个标注框的标签:类别id + 标注框中心x坐标 + 标注框中心y坐标 + 标注框宽度 + 标注框高度需要注意的是:1、这五个数据以空格间隔;2、后四个数据都在 [0,1] 范围,整张图片为1;3、类别id起始下标是0。如下图所示(我的数据集类别只有一类,所以第一列都是0):
first_sheet_01_1.txt

  classes.names这个文件中存的是每个类别的名称,上文提到了我的数据集类别是一,需要注意的是:最后务必有一行空行,否则在validation的时候会报错:IndexError: list index out of range原因在于: About a line in train.py: ap_table += [[c, class_names[c], "%.5f" % AP[i]]](解决该问题的参考链接:https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/283),所以内容如下:

classes.names
其实遇到这个问题的本质是数组越界,那么另一种解决方案就是在取值的时候取到正确的范围就行,参考的解决方法如下:https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/492
  接下来就是train.txt这个文件,其实这个文件我是在生成图片的时候一起生成的。因为需要train.txtvalid.txt这两个文件,而且里面的格式都是一样的,我就懒得再写代码去分割,其实这样不好,能用代码解决的事情,就不要手动去做,以后我得多加注意了。那么我是怎么做的呢?这两个文件存放的都是图片的相对路径,我就固定data/custom/images/,再加上文件名,train.txt中的内容如下(valid.txt中格式也一样,只不过每一行中最后的文件名不一样):
train.txt

运行 train.py 前补充

  上述步骤完成之后,还得生成yolov3-custom.cfg文件(这一步我也是之后才发觉的,论认真看readme的重要性!)。怎么生成这个文件呢?步骤如下:

$ cd config/                                # Navigate to config dir
$ bash create_custom_model.sh <num-classes> # Will create custom model 'yolov3-custom.cfg'

其中<num-classes>是类别数量,我的数据集是一个类别,就写1(不需要输入'<'和'>')。

报错(花絮)

  在运行train.py代码的时候遇到一些报错,与其称之为报错,还不如称之为花絮,为什么呢?那我就把实际情况重现一下:

TypeError: Caught TypeError in DataLoader worker process 0.

遇到这个问题的时候我就开始找博客,看看有没有解决方案(“其实你遇到的情况别人肯定也遇到过”,我很感谢我的师兄在我遇到问题的时候热心地帮我解决,他跟我说的这句话我会一直记得)。不出意外地看到了一样地报错问题,总共有两篇博客,当我仔细看,这两位博主也是在运行PyTorch-YOLOv3这个项目,列一下两篇博客地链接:https://blog.csdn.net/weixin_45093926/article/details/103330105
https://blog.csdn.net/qinglingLS/article/details/104411589
看到解决方案后我就立马按他们的方法去试呀,结果可想而知,因为已经说了是花絮,所以我还是没有解决问题。

解决问题的思路

  那我该怎么解决我的问题呢?上文也提到了,我遇到报错的第一选择是去搜博客,没错,那么我就借此机会整理一下我解决问题的思路:

  1. 报错提示为关键词去搜博客
  2. 去GitHub项目的 issues 搜 报错提示
  3. Google搜 报错提示

报错提示是最根本的原因,但是可能不是最直接的原因。但是我按以上三个步骤一步步做下来还是没有解决问题(通常以上三个步骤就可以解决我目前遇到的问题)。那怎么办?因为当时已经凌晨两点多了,虽然很想把这个问题解决了,但是身体不允许啊,于是我就关电脑去睡觉。但是躺在床上的我辗转反侧怎么也睡不着,我就在回忆我今天用到的Linux命令,打算稍微记一下,以后用起来顺手一点。事情的转折点来了!!!当我回想 cp 命令的时候(我数据集是另一个项目生成的,数据集就是通过 cp 命令复制到data/custom/images/目录),我猛地想起,我的labels目录中还没复制,也就是说:因为PyTorch-YOLOv3项目中没有自定义数据集的标签,所以导致TypeError: Caught TypeError in DataLoader worker process 0.这不是明摆着的花絮嘛,我当时躺在床上分析了一下之后,认定出错的原因就在这了,其他环节都没问题。第二天醒来复制了标签之后,果然运行成功

Warning

  看到控制台出现Warning,我有一点点强迫症,就记录一下解决方案,Warning如下:

UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead

  这已经很明显提示我们了:需要将int8转换为bool,我就按上述解决方案的步骤,这次我选择了第二步:去GitHub项目的 issues 搜 报错提示。很快就找到了解决方法,链接如下:https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/283
utils/utils.py的第279行之后插入以下代码:

obj_mask=obj_mask.bool() # convert int8 to bool
noobj_mask=noobj_mask.bool() #convert int8 to bool

完美运行!

2020-05-26补充

  每次训练时,每张图片必须包含所有类别,即classes.names这个文件中有几个类别,训练集中每张图片必须包含所有类别,否则会报错:RuntimeError: CUDA error: device-side assert triggered

相关文章

网友评论

      本文标题:PyTorch-YOLOv3训练自定义数据集

      本文链接:https://www.haomeiwen.com/subject/artsohtx.html