Garbage Classification with a CNN: A Case Study

Author: 91160e77b9d6 | Published 2020-03-16 12:43

    Source: Pinecone628 - kesci.com
    Original article: 基于CNN实现垃圾分类 (Garbage Classification with a CNN)
    Click the link above 👆 to run the notebook online without setting up an environment.
    Dataset download: an alternative garbage classification dataset with more household waste images.

    1. Introduction

    Shanghai started enforcing mandatory garbage sorting almost two weeks ago. Can we use the machine learning and deep learning methods we usually study to build a simple garbage classification model?
    Below we use a CNN to classify garbage. The dataset used here has six categories (different from Shanghai's standard): glass, paper, cardboard, plastic, metal, and general trash.
    The implementation in this article uses Keras.

    2. Importing Packages and Data

    import numpy as np
    import matplotlib.pyplot as plt
    from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
    from keras.layers import Conv2D, Flatten, MaxPooling2D, Dense
    from keras.models import Sequential
    
    import glob, os, random
    
    Using TensorFlow backend.
    
    base_path = '../input/trash_div7612/dataset-resized'
    
    # collect the paths of all images; the dataset has one sub-folder per class
    img_list = glob.glob(os.path.join(base_path, '*/*.jpg'))
    

    There are 2,527 images in total. Let's display 6 of them at random.

    print(len(img_list))
    
    2527
    
    # show 6 randomly chosen images in a 2x3 grid
    for i, img_path in enumerate(random.sample(img_list, 6)):
        img = load_img(img_path)
        img = img_to_array(img, dtype=np.uint8)
        
        plt.subplot(2, 3, i+1)
        plt.imshow(img.squeeze())
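
    It can also help to check how many images each class contributes before training. This is a small optional sketch that is not part of the original notebook; it assumes the base_path/<class>/<image>.jpg layout used above.

    from collections import Counter
    
    # count images per class directory (one sub-folder per class under base_path)
    class_counts = Counter(os.path.basename(os.path.dirname(p)) for p in img_list)
    for cls, n in sorted(class_counts.items()):
        print(cls, n)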
    

    3. Splitting the Data

    # training generator: rescale pixel values to [0, 1], apply light augmentation,
    # and hold out 10% of the data for validation
    train_datagen = ImageDataGenerator(
        rescale=1./255, shear_range=0.1, zoom_range=0.1,
        width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True,
        vertical_flip=True, validation_split=0.1)
    
    # validation generator: rescaling only, same 10% split
    test_datagen = ImageDataGenerator(
        rescale=1./255, validation_split=0.1)
        
    train_generator = train_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=16,
        class_mode='categorical', subset='training', seed=0)
    
    validation_generator = test_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=16,
        class_mode='categorical', subset='validation', seed=0)
    
    # invert class_indices so a predicted index can be mapped back to its class name
    labels = (train_generator.class_indices)
    labels = dict((v,k) for k,v in labels.items())
    
    print(labels)
    
    
    
    Found 2276 images belonging to 6 classes.
    Found 251 images belonging to 6 classes.
    {0: 'cardboard', 1: 'glass', 2: 'metal', 3: 'paper', 4: 'plastic', 5: 'trash'}
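
    Before building the model, it can be worth confirming the shapes the generators produce. A minimal check, not in the original, assuming the settings above:

    # fetch one batch from the training generator and inspect its shapes
    x_batch, y_batch = train_generator.__getitem__(0)
    print(x_batch.shape)  # expected (16, 300, 300, 3): a batch of 300x300 RGB images
    print(y_batch.shape)  # expected (16, 6): one-hot labels for the 6 classes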
    

    4. Building and Training the Model

    # four convolution + max-pooling blocks followed by a small dense classifier head
    model = Sequential([
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(300, 300, 3)),
        MaxPooling2D(pool_size=2),
    
        Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
        MaxPooling2D(pool_size=2),
        
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
        MaxPooling2D(pool_size=2),
        
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
        MaxPooling2D(pool_size=2),
    
        Flatten(),
    
        Dense(64, activation='relu'),
    
        Dense(6, activation='softmax')
    ])
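
    To inspect the resulting layer output shapes and parameter counts, we can print a summary (an optional step, not shown in the original):

    model.summary()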
    
    
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
    
    # note: with batch_size=16, steps_per_epoch=2276//32 (71 steps) covers only about half of the
    # training images per epoch; the validation pass is truncated in the same way
    model.fit_generator(train_generator, epochs=100, steps_per_epoch=2276//32, validation_data=validation_generator,
                        validation_steps=251//32)
    
    Epoch 1/100
    71/71 [==============================] - 29s 404ms/step - loss: 1.7330 - acc: 0.2236 - val_loss: 1.6778 - val_acc: 0.3393
    Epoch 2/100
    71/71 [==============================] - 25s 359ms/step - loss: 1.5247 - acc: 0.3415 - val_loss: 1.4649 - val_acc: 0.3750
    Epoch 3/100
    71/71 [==============================] - 24s 344ms/step - loss: 1.4455 - acc: 0.4006 - val_loss: 1.4694 - val_acc: 0.3832
    ...
    Epoch 98/100
    71/71 [==============================] - 24s 335ms/step - loss: 0.3936 - acc: 0.8583 - val_loss: 0.7845 - val_acc: 0.7321
    Epoch 99/100
    71/71 [==============================] - 24s 332ms/step - loss: 0.4013 - acc: 0.8503 - val_loss: 0.6881 - val_acc: 0.7664
    Epoch 100/100
    71/71 [==============================] - 24s 335ms/step - loss: 0.3275 - acc: 0.8768 - val_loss: 0.9691 - val_acc: 0.6696
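
    After training, it is usually worth persisting the model so it can be reused without retraining. A minimal sketch, not part of the original; the file name is arbitrary:

    # save the architecture and trained weights to a single HDF5 file
    model.save('trash_cnn.h5')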
    

    5. Results

    Below we take 16 images from the validation set, display them together with their true labels, and add the model's predictions.
    The predictions look reasonably accurate: the model identifies the correct class for most of the images.

    # take one batch (16 images) from the validation generator
    test_x, test_y = validation_generator.__getitem__(1)
    
    preds = model.predict(test_x)
    
    # plot the batch in a 4x4 grid, showing predicted vs. true labels
    plt.figure(figsize=(16, 16))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])], labels[np.argmax(test_y[i])]))
        plt.imshow(test_x[i])
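
    For completeness, a single image could be classified in the same way. This is an illustrative sketch only (the helper and the example path are hypothetical), assuming the 300x300 input size and 1/255 scaling used by the generators:

    def predict_image(path):
        # load and preprocess one image the same way the validation generator does
        img = load_img(path, target_size=(300, 300))
        x = img_to_array(img) / 255.
        x = np.expand_dims(x, axis=0)   # add the batch dimension
        pred = model.predict(x)
        return labels[np.argmax(pred)]
    
    # hypothetical example call:
    # print(predict_image('some_image.jpg'))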
    
