先给出参考的GitHub链接:https://github.com/eriklindernoren/PyTorch-YOLOv3
其实在readme中已经有详细的训练自定义数据集教程,我记录一下自己在配置过程中遇到的问题,以供自己复习
配置自定义数据集详细过程
一开始我是计划先下载coco数据集,跑一下yolov3,但是考虑到数据集本身比较大,而且跟我的工作没有什么直接关系,如果是想跑通程序的话,还不如直接跑自己的数据集,遇到什么问题就直接解决。
我当前的数据集是行驶证和身份证,对于数据集的生成方法我会在后续补充,现在得到的就是可以用来训练的格式,都放在data/custom
目录中,结构如下:
其中
images
目录下存放的都是图片,这就不多说了,labels
目录中存放的是每张图片对应的标签,需要注意的是:标签的文件名和图片的文件名一一对应,例如:first_sheet_01_1.txt对应first_sheet_01_1.jpg。labels
的目录结构如下:labels的目录结构
那么 first_sheet_01_1.txt 里面是什么呢?每一行代表一个标注框的标签:
类别id + 标注框中心x坐标 + 标注框中心y坐标 + 标注框宽度 + 标注框高度
,需要注意的是:1、这五个数据以空格间隔;2、后四个数据都在 [0,1] 范围,整张图片为1;3、类别id起始下标是0。如下图所示(我的数据集类别只有一类,所以第一列都是0):first_sheet_01_1.txt
classes.names
这个文件中存的是每个类别的名称,上文提到了我的数据集类别是一,需要注意的是:最后务必有一行空行,否则在validation的时候会报错:IndexError: list index out of range
原因在于: About a line in train.py: ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
(解决该问题的参考链接:https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/283),所以内容如下:
其实遇到这个问题的本质是数组越界,那么另一种解决方案就是在取值的时候取到正确的范围就行,参考的解决方法如下:https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/492
接下来就是
train.txt
这个文件,其实这个文件我是在生成图片的时候一起生成的。因为需要train.txt
和valid.txt
这两个文件,而且里面的格式都是一样的,我就懒得再写代码去分割,其实这样不好,能用代码解决的事情,就不要手动去做,以后我得多加注意了。那么我是怎么做的呢?这两个文件存放的都是图片的相对路径,我就固定data/custom/images/
,再加上文件名,train.txt
中的内容如下(valid.txt
中格式也一样,只不过每一行中最后的文件名不一样):train.txt
运行 train.py 前补充
上述步骤完成之后,还得生成yolov3-custom.cfg
文件(这一步我也是之后才发觉的,论认真看readme的重要性!)。怎么生成这个文件呢?步骤如下:
$ cd config/ # Navigate to config dir
$ bash create_custom_model.sh <num-classes> # Will create custom model 'yolov3-custom.cfg'
其中<num-classes>
是类别数量,我的数据集是一个类别,就写1(不需要输入'<'和'>')。
报错(花絮)
在运行train.py
代码的时候遇到一些报错,与其称之为报错,还不如称之为花絮,为什么呢?那我就把实际情况重现一下:
TypeError: Caught TypeError in DataLoader worker process 0.
遇到这个问题的时候我就开始找博客,看看有没有解决方案(“其实你遇到的情况别人肯定也遇到过”,我很感谢我的师兄在我遇到问题的时候热心地帮我解决,他跟我说的这句话我会一直记得)。不出意外地看到了一样地报错问题,总共有两篇博客,当我仔细看,这两位博主也是在运行PyTorch-YOLOv3这个项目,列一下两篇博客地链接:https://blog.csdn.net/weixin_45093926/article/details/103330105
https://blog.csdn.net/qinglingLS/article/details/104411589
看到解决方案后我就立马按他们的方法去试呀,结果可想而知,因为已经说了是花絮,所以我还是没有解决问题。
解决问题的思路
那我该怎么解决我的问题呢?上文也提到了,我遇到报错的第一选择是去搜博客,没错,那么我就借此机会整理一下我解决问题的思路:
- 以报错提示为关键词去搜博客
- 去GitHub项目的 issues 搜 报错提示
- Google搜 报错提示
报错提示是最根本的原因,但是可能不是最直接的原因。但是我按以上三个步骤一步步做下来还是没有解决问题(通常以上三个步骤就可以解决我目前遇到的问题)。那怎么办?因为当时已经凌晨两点多了,虽然很想把这个问题解决了,但是身体不允许啊,于是我就关电脑去睡觉。但是躺在床上的我辗转反侧怎么也睡不着,我就在回忆我今天用到的Linux命令,打算稍微记一下,以后用起来顺手一点。事情的转折点来了!!!当我回想 cp 命令的时候(我数据集是另一个项目生成的,数据集就是通过 cp 命令复制到data/custom/images/
目录),我猛地想起,我的labels
目录中还没复制,也就是说:因为PyTorch-YOLOv3项目中没有自定义数据集的标签,所以导致TypeError: Caught TypeError in DataLoader worker process 0.
这不是明摆着的花絮嘛,我当时躺在床上分析了一下之后,认定出错的原因就在这了,其他环节都没问题。第二天醒来复制了标签之后,果然运行成功。
Warning
看到控制台出现Warning,我有一点点强迫症,就记录一下解决方案,Warning如下:
UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead
这已经很明显提示我们了:需要将int8转换为bool,我就按上述解决方案的步骤,这次我选择了第二步:去GitHub项目的 issues 搜 报错提示。很快就找到了解决方法,链接如下:https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/283
在utils/utils.py
的第279行之后插入以下代码:
obj_mask=obj_mask.bool() # convert int8 to bool
noobj_mask=noobj_mask.bool() #convert int8 to bool
完美运行!
2020-05-26补充
每次训练时,每张图片必须包含所有类别,即classes.names
这个文件中有几个类别,训练集中每张图片必须包含所有类别,否则会报错:RuntimeError: CUDA error: device-side assert triggered
网友评论