美文网首页AI学习
Pytorch多标签CNN端到端验证码识别

Pytorch多标签CNN端到端验证码识别

作者: 阳光树林 | 来源:发表于2018-01-04 17:32 被阅读3046次

    这其实是一个多标签分类问题,每个验证码图片有4个字符(标签),并且顺序固定;只要将卷积神经网络的最后一层稍加修改就能实现多标签分类。

    如下图所示,我们的验证码一共有4个数字,将4个数字转换成40位one_hot形式,输出层的[0-9]输出值对应第一个字符的onehot编码,[10-19]输出值对应第二个字符的onehot编码,[20-29]输出值对应第三个字符的onehot编码,[30-39]输出值对应第四个字符的onehot编码,并使用pytorch的多标签分类损失nn.MultiLabelSoftMarginLoss作为损失函数。

    image.png

    训练集800张图片,测试集200张,每张图片大小20*60

    模型结构:
    CNN (
    (conv1): Sequential (
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU ()
    (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    )
    (conv2): Sequential (
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU ()
    (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    )
    (out): Linear (624 -> 40)
    )

    # coding: utf-8
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torch.nn.functional as F
    import numpy as np
    import math
    import csv
    import cv2
    
    # Read the labels: each CSV row supplies an image base name (column 0)
    # and that image's label string (column 1); any further columns are ignored.
    # The `with` block closes the file even if a row is malformed, and
    # newline='' is how the csv module expects its input files to be opened.
    with open('GenPics/label.csv', newline='') as csvfile:
        lables = [[line[0], line[1]] for line in csv.reader(csvfile)]
    
    X = []
    y = []

    # Load every labelled image from GenPics/ as a single-channel (grayscale) array.
    picnum = len(lables)
    print("picnum : ", picnum)
    for name, label in lables:
        img_name = "GenPics/" + name + '.jpg'
        img = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here rather than in a later tensor conversion.
            raise IOError("cannot read image: " + img_name)
        X.append(img)
        y.append(label)

    # Flatten the labels into one integer per character: the first four
    # characters of every label, in image order (4 entries per image).
    # Indexing (rather than slicing) keeps the IndexError on labels shorter than 4.
    tmp = []
    for label in y:
        tmp.extend(int(label[k]) for k in range(4))
    
    # Stack the images into one tensor, insert a channel axis at dim 1 and
    # divide by 255 so the network receives float inputs.
    X = np.array(X)
    X = torch.from_numpy(X)
    X = torch.unsqueeze(X, dim=1)
    X = X.type(torch.FloatTensor)/255.

    # One-hot encode the labels.
    # The number of rows must equal the number of digits actually loaded; a
    # hard-coded count only works for one exact dataset size.
    batch_size = len(tmp)
    yt = torch.LongTensor(tmp)
    yt = torch.unsqueeze(yt, 1)
    # One row of 10 zeros per digit, with a 1 written at the digit's index.
    yt_onehot = torch.zeros(batch_size, 10)
    yt_onehot.scatter_(1, yt, 1)
    # Regroup every 4 consecutive digit rows into a single 40-wide target per image.
    yt_onehot = yt_onehot.view(-1, 40)
    y = yt_onehot
    
    # Split into training and test sets: the first 800 samples are used for
    # training, everything after them is held out for testing.
    train_x = X[:800]
    train_y = y[:800]
    # Kept as a plain tensor: since PyTorch 0.4 `Variable(..., volatile=True)`
    # only emits a warning and has no effect, so wrapping adds nothing.
    test_x = X[800:]
    test_y = y[800:]
    
    #定义模型
    class CNN(nn.Module):
        """Two conv/ReLU/max-pool stages followed by a single 40-output linear layer."""

        def __init__(self):
            super().__init__()

            # Stage 1: 1 -> 32 channels, 3x3 kernel, stride 1, no padding, 2x2 max-pool.
            stage1_conv = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=0)
            self.conv1 = nn.Sequential(stage1_conv, nn.ReLU(), nn.MaxPool2d(2))

            # Stage 2: 32 -> 16 channels with the same kernel/stride/padding and pooling.
            stage2_conv = nn.Conv2d(
                in_channels=32,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=0,
            )
            self.conv2 = nn.Sequential(stage2_conv, nn.ReLU(), nn.MaxPool2d(kernel_size=2))

            # 16 channels x 3 x 13 positions = 624 features, which is what a
            # 1x20x60 input is reduced to by the two stages above.
            flat_features = 16 * 3 * 13
            self.out = nn.Linear(flat_features, 40)

        def forward(self, x):
            """Return the 40 raw output scores for each sample in the batch `x`."""
            features = self.conv2(self.conv1(x))
            # Collapse everything except the batch axis before the linear layer.
            flat = features.view(features.size(0), -1)
            return self.out(flat)
    # Instantiate the network and print its layer structure.
    cnn = CNN()
    print(cnn)
    # Sample output of the print above:
    # CNN (
      # (conv1): Sequential (
        # (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        # (1): ReLU ()
        # (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
      # )
      # (conv2): Sequential (
        # (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
        # (1): ReLU ()
        # (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
      # )
      # (out): Linear (624 -> 40)
    # )
    
    # Define the optimizer and the loss function.
    # batsize is the mini-batch size and epochs the number of passes used by
    # the training loop below.
    batsize = 8
    epochs = 10
    # Adam over all model parameters with a learning rate of 0.001.
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
    # Multi-label loss applied to the 40-wide outputs and their one-hot targets.
    loss_func = nn.MultiLabelSoftMarginLoss()
    
    # Train for `epochs` passes over the training set in mini-batches of `batsize`.
    for epoch in range(epochs):
        losses = []
        # Stepping the start offset by batsize visits every sample once per
        # epoch; the final batch is simply shorter when the sizes do not divide.
        for start in range(0, train_x.shape[0], batsize):
            # Tensors are passed to the model directly: the Variable wrapper
            # has been unnecessary since PyTorch 0.4.
            tx = train_x[start:start + batsize]
            ty = train_y[start:start + batsize]
            out = cnn(tx)
            loss = loss_func(out, ty)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() stores a plain Python float, so the running list holds
            # no tensors and np.mean below works on ordinary numbers.
            losses.append(loss.item())
        print('[%d/%d] Loss: %.3f' % (epoch+1, epochs, np.mean(losses)))
    # Sample output:
    # [1/10] Loss: 0.352
    # [2/10] Loss: 0.322
    # [3/10] Loss: 0.244
    # [4/10] Loss: 0.100
    # [5/10] Loss: 0.053
    # [6/10] Loss: 0.040
    # [7/10] Loss: 0.035
    # [8/10] Loss: 0.031
    # [9/10] Loss: 0.028
    # [10/10] Loss: 0.026
    
    # Measure accuracy on the held-out test set.
    # Gradients are not needed for evaluation, so skip building the autograd graph.
    with torch.no_grad():
        test_output = cnn(test_x)
    # View each 40-wide output row as 4 groups of 10 scores and take the
    # highest-scoring index within every group: one predicted digit per position.
    pred_digits = np.argmax(test_output.view(-1, 4, 10).numpy(), axis=2)
    correct_num = 0
    for i, digits in enumerate(pred_digits):
        c = '%s%s%s%s' % tuple(digits)
        # Test samples start at index 800 of the label list; a prediction only
        # counts when all four characters match the label string.
        if c == lables[800+i][1]:
            correct_num += 1
    print("Test accurate :", float(correct_num)/ len(test_output))
    # Sample output:
    # Test accurate : 0.98
    
    # Predict the code shown in a single image.
    img_path = 'test2.jpg'
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising when the file cannot be read.
        raise IOError("cannot read image: " + img_path)
    # Add a leading batch axis and a channel axis, then apply the same /255
    # scaling as the training data.
    imgArr = torch.from_numpy(np.array(img))
    imgArr = imgArr.unsqueeze(0).unsqueeze(1)
    imgArr = imgArr.type(torch.FloatTensor)/255.
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred_img = cnn(imgArr)
    # The single 40-wide output row holds 4 groups of 10 scores; the index of
    # the best score in each group is the predicted digit for that position.
    digits = np.argmax(pred_img.view(4, 10).numpy(), axis=1)
    c = '%s%s%s%s' % tuple(digits)
    print(c)
    # Sample output:
    # 5955
    # Display the image that was just classified.
    import matplotlib.pyplot as plt
    img = plt.imread(img_path)
    plt.imshow(img)
    plt.show()
    
    image.png

    参考引用:https://github.com/junliangliu/captcha

    相关文章

      网友评论

        本文标题:Pytorch多标签CNN端到端验证码识别

        本文链接:https://www.haomeiwen.com/subject/yvornxtx.html