简介
本节主要是介绍我怎么用上一节实现的UNet进行训练,一共分成3部分进行说明。需要强调的是,本节中的数据集以及很多模型训练想法都是来自【Keras】基於SegNet和U-Net的遙感圖像語義分割,我主要的工作就是将keras的代码用pytorch进行了实现。在上面的链接里,该作者对他们设计模型以及数据处理进行了较为详细的介绍。
刚开始我自己用pytorch实现了训练的模型,但是感觉并不是很好,主要是代码的结构不喜欢,后来在github上找到了一个pytorch训练测试模型的模板,然后把代码加到模板里,训练起来确实方便了不少,原始的模板: pytorch-template。大家可以用一下试试,我个人感觉如果能自己写还是自己写,等写熟了再用模板应该会好一点。
数据集
原始数据和数据集生成的代码都放在了百度网盘,提取码:7rtr。下载以后先将数据集解压,然后运行里面的gen_dataset.py,注意修改里面的文件路径,然后使用traverse_dataset.py遍历数据集,将文件名存进txt中,用于pytorch加载数据。
模型训练
因为是用的别人的训练模板,所以需要做的事情不是很多,只需要按照别人模板说明针对性的修改文件就行了,训练的话就:
python train.py --config config.json
至于代码,可以从github上下载我修改后的代码,可以直接运行:github地址。
结果分析
训练的过程中有保存日志,我还用tensorboard看了。但是很奇怪的是,训练完以后再打开结果就只显示最初的一段训练结果,不停的刷新页面可以更新,点了半天刷新,有点累,只得到下面的图:
这个是加了norm batch的训练结果,和没加的区别不是很大,我印象中加了以后好像训练过程中loss波动要大一些,不过下降趋势都差不多,最终的Accuracy也很接近,大致就是这样。可能是状态不好,感觉写点东西,浑身难受,有点烦躁。
网友评论