1 问题描述
怎么能从剪切秧苗中分辨出杂草?
如果能够很好的实现分辨杂草可以有效提高产量和环境管理。
Aarhus大学信号处理研究小组,与Southern Denmark大学合作,发布了一个数据集。该数据集包含各个生长阶段大约960颗幼苗,它们属于12种植物。
图1若要获取数据集更多信息,可以参考original data。
论文参考A Public Image Database for Benchmark of Plant Seedling Classification Algorithms
2 数据描述
该数据集包含了不同生长阶段的植物种子的训练和测试集。每个图像是一个文件且对应一个唯一的id。数据集包括12种植物。我们的目标是创建一个分类器,实现从一张图片中识别出植物的种类,下面植物种类的列表:
Black-grass
Charlock
Cleavers
Common Chickweed
Common wheat
Fat Hen
Loose Silky-bent
Maize
Scentless Mayweed
Shepherds Purse
Small-flowered Cranesbill
Sugar beet
文件描述
train.csv - 训练集,具体的图像文件在文件夹中
test.csv - 测试集, 需要识别的图片
sample_submission.csv - 上传结果文件的格式
3 数据预处理
图片的处理主要使用opencv来做。
标签处理采用one-hot编码。
def label_img(word_label):
if word_label == 'Black-grass': return [1,0,0,0,0,0,0,0,0,0,0,0]
elif word_label == 'Charlock': return [0,1,0,0,0,0,0,0,0,0,0,0]
elif word_label == 'Cleavers': return [0,0,1,0,0,0,0,0,0,0,0,0]
elif word_label == 'Common Chickweed': return [0,0,0,1,0,0,0,0,0,0,0,0]
elif word_label == 'Common wheat': return [0,0,0,0,1,0,0,0,0,0,0,0]
elif word_label == 'Fat Hen': return [0,0,0,0,0,1,0,0,0,0,0,0]
elif word_label == 'Loose Silky-bent': return [0,0,0,0,0,0,1,0,0,0,0,0]
elif word_label == 'Maize': return [0,0,0,0,0,0,0,1,0,0,0,0]
elif word_label == 'Scentless Mayweed': return [0,0,0,0,0,0,0,0,1,0,0,0]
elif word_label == 'Shepherds Purse': return [0,0,0,0,0,0,0,0,0,1,0,0]
elif word_label == 'Small-flowered Cranesbill': return [0,0,0,0,0,0,0,0,0,0,1,0]
elif word_label == 'Sugar beet': return [0,0,0,0,0,0,0,0,0,0,0,1]
然后建立train数据集。
def create_train_data():
train = []
for category_id, category in enumerate(CATEGORIES):
for img in tqdm(os.listdir(os.path.join(train_dir, category))):
label=label_img(category)
path=os.path.join(train_dir,category,img)
img=cv2.imread(path,cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (IMG_SIZE,IMG_SIZE))
train.append([np.array(img),np.array(label)])
shuffle(train)
return train
然后对训练数据按照训练集和验证集进行划分。
4 模型
模型代码看这里。
使用深度学习调试技巧。
先简单的做了下模型训练:
- 没有使用数据增强
- 使用小批量数据优化
- 使用SGD优化算法
- 先直接算训练误差,然后计算得到训练误差和验证误差
批处理训练交叉熵为0.051953293,验证集交叉熵为0.29346284。
本来想看下准确率,但是准确率为0,我了个去!倒吸了口凉气,WTF!是什么原因使得准确率为0?
不平衡样本处理
使用sklearn的stratifiedKFold函数来实现K折交叉验证。
数据增强
5 评估
epoch为5000次:训练集准确率为0.836675。
6 总结
这里简单使用了MobileNet V1来识别种子图片。
S1: 对数据做了预处理;
S2: 建立MobileNet V1模型;
S3: 根据网络调试技巧来训练网络;
S4: 保存模型,以备后用。
网友评论