PyTorch Learning (20): Training a Binary Gender Classification Network

Author: 侠之大者_7d3f | Published 2019-01-11 11:15

    Preface

    I once took part in a face-attribute recognition contest held by ZTE. The task was to recognize faces in different scenes and output: gender, age, whether glasses are worn, glasses color, whether a mask is worn, mask color, whether a hat is worn, and hat color. At the time I knew little about deep learning, so I simply found a .caffemodel and ran it through OpenCV. For mask and glasses recognition, my idea was to use a facial-landmark-detection SDK to locate the facial features, crop sub-images around the eye and mouth regions, convert them to the YCrCb and HSV color spaces, and judge the color there.
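That color-judgment step can be sketched without any SDK: average the cropped patch's pixels in HSV and threshold. A minimal illustration using only the standard library (the thresholds are made-up assumptions, not the contest code):

```python
import colorsys

def patch_color(pixels):
    """Roughly classify a cropped patch from its mean HSV values.

    `pixels` is a list of (r, g, b) tuples in 0-255. The thresholds
    below are illustrative guesses, not tuned values.
    """
    n = len(pixels)
    h_sum = s_sum = v_sum = 0.0
    for r, g, b in pixels:
        h, s, v = colorsys.rgb_to_hsv(r / 255.0, g / 255.0, b / 255.0)
        h_sum, s_sum, v_sum = h_sum + h, s_sum + s, v_sum + v
    s_mean, v_mean = s_sum / n, v_sum / n
    if v_mean < 0.2:                      # uniformly dark patch
        return 'black'
    if s_mean < 0.15 and v_mean > 0.8:    # bright and unsaturated
        return 'white'
    return 'colored (mean hue {:.2f})'.format(h_sum / n)
```

In practice one would threshold per channel on the cropped glasses/mask region; this just shows the color-space idea.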

    Back then I used someone else's trained caffemodel directly and was not familiar with the actual training process. Having now learned some PyTorch, I decided to start by implementing gender recognition. Gender recognition is a binary classification problem and has been studied widely; SenseTime and Megvii, for example, have long integrated face attributes into their SDKs, available as online APIs as well as Android and iOS SDKs. I have tested them, and both speed and accuracy are quite good.

    For simplicity, I train the gender classifier by fine-tuning a pretrained model.


    Development/Test Environment

    • Ubuntu 18.04
    • Anaconda 3, Python 3.6
    • PyTorch 1.0
    • torchnet

    Choosing a Network Model

    torchvision.models ships several common architectures: ResNet, AlexNet, VGG, DenseNet, and SqueezeNet. The AlexNet and VGG model files are large (AlexNet is about 230 MB, VGG even larger), which makes downloads slow, and such large models are a poor fit for later porting to mobile platforms.

    I knew a little about SqueezeNet: it is a lightweight network ("squeeze" as in compression), and the paper reports AlexNet-level accuracy with far fewer parameters. I therefore decided to experiment with SqueezeNet; if it works well, porting to Android becomes an option.

    SqueezeNet


    Building the Dataset

    I previously downloaded a small dataset from CSDN with only 200 male and 200 female images; fine-tuning on it without any data augmentation gave poor results. This time I train on the UTKFace dataset.

    UTKFace

    https://susanqq.github.io/UTKFace/


    Train/Validation/Test Split

    UTKFace must be downloaded from Google Drive; the page links to two archives. I use the first, smaller archive to get a quick look at the results. Split ratio: train : val : test = 6 : 2 : 2.

    • 9,780 images in total
    • Training: 5,000+
    • Validation: ~2,000
    • Test: ~2,000

    Because the UTKFace files are ordered by age, the split is done entirely by random sampling.
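The 6:2:2 random split amounts to a single shuffle followed by slicing. A minimal sketch (the helper name and fixed seed are my own; the original writes the names into train.txt/val.txt/test.txt):

```python
import random

def split_dataset(names, ratios=(0.6, 0.2, 0.2), seed=0):
    """Shuffle once (UTKFace files are age-sorted), then slice 6:2:2."""
    names = list(names)
    random.Random(seed).shuffle(names)
    n = len(names)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    return (names[:n_train],
            names[n_train:n_train + n_val],
            names[n_train + n_val:])
```

Writing each returned list to its own txt file, one name per line, produces the split files the Dataset class reads below.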

    Loading the Data with PyTorch

    • Subclass Dataset and override the __len__() and __getitem__() methods
    • Wrap it in a DataLoader to read mini-batches
    from torch.utils.data import Dataset
    import PIL.Image as Image
    import os
    
    
    class UTKFaceGenderDataset(Dataset):
    
        def __init__(self, root, txt_file, transform=None, target_transform=None):
            self.root = root
            self.transform = transform
            self.target_transform = target_transform
            self.class_name = {0: 'male', 1: 'female'}
            self.txt_file = txt_file
            # Read the image file names (one per line) from the split file.
            with open(txt_file, 'r') as f:
                self.images_name = [line.rstrip() for line in f]
    
        def __len__(self):
            return len(self.images_name)
    
        def __getitem__(self, index):
            image_name = self.images_name[index]
            image = Image.open(os.path.join(self.root, image_name))
            # UTKFace names follow [age]_[gender]_[race]_[date&time].jpg,
            # where gender is 0 (male) or 1 (female).
            label = int(image_name.split('_')[1])
    
            if self.transform:
                image = self.transform(image)
            if self.target_transform:
                label = self.target_transform(label)
    
            return {'image': image, 'label': label}
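The label lookup in `__getitem__` leans entirely on the UTKFace naming scheme `[age]_[gender]_[race]_[date&time].jpg` (gender: 0 = male, 1 = female). With a made-up file name:

```python
# Made-up example name in the UTKFace format [age]_[gender]_[race]_[date&time].jpg
name = '26_1_2_20170116171048641.jpg'

age = int(name.split('_')[0])     # first field: age
gender = int(name.split('_')[1])  # second field: gender code
class_name = {0: 'male', 1: 'female'}[gender]
```

Any file whose gender field is not 0 or 1 is malformed; the commented-out sanity check in the complete project below scans for such names.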
    
    

    Wrapping with a DataLoader

    import os
    import torch
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    
    import dataset  # the module containing UTKFaceGenderDataset
    
    # ---------------------------Dataset--------------------------------------------------
    batch_size = 8
    data_root = '/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])
    
    train_dataset = dataset.UTKFaceGenderDataset(root=os.path.join(data_root, 'image'),
                                                 txt_file=os.path.join(data_root, 'train.txt'),
                                                 transform=transform)
    
    print('train_dataset: {}'.format(len(train_dataset)))
    
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=4)
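Because `__getitem__` returns a dict, the default collate function stacks each field across the batch. A toy stand-in (random tensors instead of decoded images; `ToyFaces` is not part of the project) shows the resulting batch shapes:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyFaces(Dataset):
    """Stand-in dataset with the same {'image', 'label'} dict structure."""
    def __len__(self):
        return 32

    def __getitem__(self, index):
        return {'image': torch.rand(3, 224, 224), 'label': index % 2}

loader = DataLoader(ToyFaces(), batch_size=8, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch['image'].shape)  # torch.Size([8, 3, 224, 224])
print(batch['label'].shape)  # torch.Size([8])
```

The integer labels are collated into a LongTensor, which is exactly what nn.CrossEntropyLoss expects in the training loop.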
    
    

    Training

    The full training script is listed in the Complete Project section below.


    Testing

    import torch
    import visdom
    from torchvision.utils import make_grid
    
    # model and test_dataloader are set up as in the test script below.
    viz = visdom.Visdom(env='test')
    GENDER = ['male', 'female']
    
    correct = 0
    total = 0
    for i, sample in enumerate(test_dataloader):
        inputs, labels = sample['image'], sample['label']
    
        outputs = model(inputs)
    
        _, prediction = torch.max(outputs, 1)
        correct += (labels == prediction).sum().item()
        total += labels.size(0)
    
        # Show each batch (batch_size=4) with its ground-truth labels as the title
        inputs = make_grid(inputs)
        viz.image(inputs, opts=dict(title='{},{},{},{}'.format(
            GENDER[labels[0].item()], GENDER[labels[1].item()],
            GENDER[labels[2].item()], GENDER[labels[3].item()])))
    
    

    Output


    Complete Project

    • Dataset
    from torch.utils.data import Dataset
    import torch
    import torchvision.transforms as transforms
    import PIL.Image as Image
    import os
    import random
    
    
    class UTKFaceGenderDataset(Dataset):
    
        def __init__(self, root, txt_file, transform=None, target_transform=None):
            self.root = root
            self.transform = transform
            self.target_transform = target_transform
            self.class_name = {0: 'male', 1: 'female'}
            self.txt_file = txt_file
            # Read the image file names (one per line) from the split file.
            with open(txt_file, 'r') as f:
                self.images_name = [line.rstrip() for line in f]
    
        def __len__(self):
            return len(self.images_name)
    
        def __getitem__(self, index):
            image_name = self.images_name[index]
            image = Image.open(os.path.join(self.root, image_name))
            # UTKFace names follow [age]_[gender]_[race]_[date&time].jpg,
            # where gender is 0 (male) or 1 (female).
            label = int(image_name.split('_')[1])
    
            if self.transform:
                image = self.transform(image)
            if self.target_transform:
                label = self.target_transform(label)
    
            return {'image': image, 'label': label}
    
    
    # train_file = open('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/train.txt', 'w')
    # val_file = open('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/val.txt', 'w')
    # test_file = open('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/test.txt', 'w')
    #
    # image_idx = list(range(len(os.listdir('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/image'))))
    # images_name = os.listdir('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/image')
    #
    # for i in range(1956):
    #     num = 0
    #     while True:
    #         num = random.randint(a=image_idx[0], b=image_idx[len(image_idx)-1]-1)
    #         if num in image_idx:
    #             break
    #     image_name = images_name[num]
    #     test_file.write(image_name + '\n')
    #     image_idx.remove(num)
    #     print(i)
    #
    # test_file.close()
    # print('test.txt create finish!')
    #
    # for i in range(1956):
    #     num = 0
    #     while True:
    #         num = random.randint(a=image_idx[0], b=image_idx[len(image_idx)-1]-1)
    #         if num in image_idx:
    #             break
    #     image_name = images_name[num]
    #     val_file.write(image_name + '\n')
    #     image_idx.remove(num)
    #     print(i)
    #
    # val_file.close()
    # print('val.txt create finish!')
    #
    # for i in image_idx:
    #     train_file.write(images_name[i] + '\n')
    # print('train.txt create finish!')
    
    
    # ---------------------------Test--------------------------------------------------
    # length = len(os.listdir('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/image'))
    #
    #
    #
    # batch_size = 8
    # data_root = '/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/'
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    #
    # transform = transforms.Compose(
    #     [transforms.Resize((224, 224)),
    #      transforms.ToTensor()
    #      ])
    #
    # train_dataset = UTKFaceGenderDataset(root=os.path.join(data_root, 'image'),
    #                                              txt_file=os.path.join(data_root, 'train.txt'),
    #                                              transform=transform,
    #                                              target_transform=ToTensor())
    #
    # print('train_dataset: {}'.format(len(train_dataset)))
    #
    #
    # val_dataset = UTKFaceGenderDataset(root=os.path.join(data_root, 'image'),
    #                                            txt_file=os.path.join(data_root, 'val.txt'),
    #                                            transform=transform,
    #                                            target_transform=ToTensor()
    #                                    )
    # print('val dataset: {}'.format(len(val_dataset)))
    #
    # datasets = [train_dataset, val_dataset]
    # for dataset in datasets:
    #     print('-'*20)
    #     for i, sample in enumerate(dataset):
    #         print('{}, {}, label={}'.format(dataset.images_name[i].rstrip(), i, sample['label'].item()))
    #         # if (sample['label'].item() == 0) or (sample['label'].item() == 1):
    #         #     continue
    #         # else:
    #         #     os.remove(os.path.join('/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/image',dataset.images_name[i].rstrip()))
    #         #     continue
    #         assert sample['label'].item() == 0 or sample['label'].item() == 1
    
    
    
    • Training + Validation
    import torch
    import torch.optim as optim
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder
    import torchvision.models as models
    import torchvision.transforms as transforms
    import numpy as np
    import copy
    import matplotlib.pyplot as plt
    import dataset
    import os
    import torchnet
    
    # ---------------------------Dataset--------------------------------------------------
    batch_size = 8
    data_root = '/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()
         ])
    
    train_dataset = dataset.UTKFaceGenderDataset(root=os.path.join(data_root, 'image'),
                                                 txt_file=os.path.join(data_root, 'train.txt'),
                                                 transform=transform)
    
    print('train_dataset: {}'.format(len(train_dataset)))
    
    train_dataloader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True, num_workers=4)
    
    # plt.figure()
    # for i in train_dataset:
    #     plt.imshow(np.transpose(i['image'].numpy(), (1, 2, 0)))
    #     plt.title(train_dataset.class_name[i['label']])
    #     plt.show()
    
    val_dataset = dataset.UTKFaceGenderDataset(root=os.path.join(data_root, 'image'),
                                               txt_file=os.path.join(data_root, 'val.txt'),
                                               transform=transform)
    print('val dataset: {}'.format(len(val_dataset)))
    
    val_dataloader = DataLoader(dataset=val_dataset,batch_size=batch_size, shuffle=False, num_workers=4)
    
    # ------------------Define the network---------------------------------
    # Load the pretrained model
    model = models.squeezenet1_1(pretrained=True)
    model.classifier[1] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1, 1), stride=(1, 1))
    model.num_classes = 2
    print(model)
    # print('Download finished')
    # model = models.alexnet(pretrained=True)
    # # Change the output layer for 2-class classification
    # model.classifier[6] = nn.Linear(in_features=4096, out_features=2)
    model.to(device)
    
    # ------------------Optimizer and loss function--------------------------------------
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_fc = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.StepLR(optimizer, 10, 0.1)
    
    # ------------------Training---------------------------------------------------------
    num_epoch = 10
    
    # Save training logs
    file_train_loss = open('./log/train_loss.txt', 'w')
    file_train_acc = open('./log/train_acc.txt', 'w')
    
    file_val_loss = open('./log/val_loss.txt', 'w')
    file_val_acc = open('./log/val_acc.txt', 'w')
    
    # Loss visualization (optional)
    # win_loss = torchnet.logger.VisdomPlotLogger(plot_type='line',
    #                                             env='gender_classfiy',
    #                                             opts=dict(title='Train loss'),
    #                                             win='Train loss')
    # 
    # # Accuracy visualization (optional)
    # win_acc = torchnet.logger.VisdomPlotLogger(plot_type='line',
    #                                            env='gender_classify',
    #                                            opts=dict(title='Val acc'),
    #                                            win='Val acc')
    
    acc_best_wts = model.state_dict()
    best_acc = 0
    iter_count = 0
    
    for epoch in range(num_epoch):
        train_loss = 0.0
        train_acc = 0.0
        train_correct = 0
        train_total = 0
    
        val_loss = 0.0
        val_acc = 0.0
        val_correct = 0
        val_total = 0
    
        scheduler.step()
        for i, sample_batch in enumerate(train_dataloader):
            inputs = sample_batch['image'].to(device)
            labels = sample_batch['label'].to(device)
    
        # Set the model to training mode
            model.train()
    
            # forward
            outputs = model(inputs)
    
            # print(labels)
            # loss
            loss = loss_fc(outputs, labels)
    
        # backward + update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        # Statistics
            train_loss += loss.item()
            train_correct += (torch.max(outputs, 1)[1] == labels).sum().item()
            train_total += labels.size(0)
    
            print('iter:{}'.format(i))
    
            if i % 200 == 199:
                for sample_batch in val_dataloader:
                    inputs = sample_batch['image'].to(device)
                    labels = sample_batch['label'].to(device)
    
                    model.eval()
                    outputs = model(inputs)
                    loss = loss_fc(outputs, labels)
                    _, prediction = torch.max(outputs, 1)
                    val_correct += ((labels == prediction).sum()).item()
                    val_total += inputs.size(0)
                    val_loss += loss.item()
    
                val_acc = val_correct / val_total
                print('[{},{}] train_loss = {:.5f} train_acc = {:.5f} val_loss = {:.5f} val_acc = {:.5f}'.format(
                epoch + 1, i + 1, train_loss / 200, train_correct / train_total, val_loss / len(val_dataloader),
                    val_correct / val_total))
                if val_acc > best_acc:
                    best_acc = val_acc
                    acc_best_wts = copy.deepcopy(model.state_dict())
    
            file_train_loss.write(str(train_loss / 200) + '\n')
                file_train_acc.write(str(train_correct / train_total) + '\n')
                file_val_loss.write(str(val_loss/len(val_dataloader)) + '\n')
                file_val_acc.write(str(val_correct / val_total) + '\n')
    
                iter_count += 200
    
            # Visualization (optional)
                # win_loss.log(iter_count, train_loss)
                # win_acc.log(iter_count, val_acc)
    
                train_loss = 0.0
                train_total = 0
                train_correct = 0
                val_correct = 0
                val_total = 0
                val_loss = 0
    
    
    print('Train finish!')
    # Save the best model
    torch.save(acc_best_wts, './models/model_squeezenet_utk_face_1.pth')
    print('Model save ok!')
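For reference, the `StepLR(optimizer, 10, 0.1)` schedule used above multiplies the learning rate by 0.1 every 10 epochs; it reduces to a one-line formula (a sketch of the schedule's math, not the library code):

```python
def step_lr(base_lr=0.001, step_size=10, gamma=0.1, epoch=0):
    # lr(epoch) = base_lr * gamma ** floor(epoch / step_size)
    return base_lr * (gamma ** (epoch // step_size))

for e in (0, 9, 10, 19, 20):
    print('epoch {:2d}: lr = {}'.format(e, step_lr(epoch=e)))
```

With `num_epoch = 10`, training ends just as the first decay would take effect; raising `num_epoch` is what makes the schedule matter.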
    
    • Testing
    import torch
    import torch.optim as optim
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder
    import torchvision.models as models
    import torchvision.transforms as transforms
    import numpy as np
    import copy
    import matplotlib.pyplot as plt
    import dataset
    import os
    
    
    data_root = '/media/weipenghui/Extra/人脸属性识别/UTKFace/crop_part1/'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()
         ])
    
    test_dataset = dataset.UTKFaceGenderDataset(root=os.path.join(data_root, 'image'),
                                                 txt_file=os.path.join(data_root, 'test.txt'),
                                                 transform=transform)
    
    print('test_dataset: {}'.format(len(test_dataset)))
    
    
    model = models.squeezenet1_1(pretrained=True)
    model.classifier[1] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1, 1), stride=(1, 1))
    model.num_classes = 2
    model.load_state_dict(torch.load('./models/model_squeezenet_utk_face_20.pth', map_location='cpu'))
    print(model)
    model.eval()
    
    
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False, num_workers=4)
    
    correct = 0
    total = 0
    acc = 0.0
    with torch.no_grad():  # no gradients needed at test time
        for i, sample in enumerate(test_dataloader):
            inputs, labels = sample['image'], sample['label']
    
            outputs = model(inputs)
    
            _, prediction = torch.max(outputs, 1)
            correct += (labels == prediction).sum().item()
            total += labels.size(0)
    
    acc = correct / total
    print('test finish, total:{}, correct:{}, acc:{:.3f}'.format(total, correct, acc))
    
    
    
    • Parsing the Logs and Visualizing Loss/Accuracy
    import numpy as np
    import visdom
    
    
    train_loss = np.loadtxt('./log/train_loss2.txt', dtype=float)
    train_acc = np.loadtxt('./log/train_acc2.txt', dtype=float)
    val_loss = np.loadtxt('./log/val_loss2.txt', dtype=float)
    val_acc = np.loadtxt('./log/val_acc2.txt', dtype=float)
    
    
    viz = visdom.Visdom(env='gender_classifier')
    viz.line(Y=train_loss, win='train_loss', opts=dict(title='train_loss'))
    viz.line(Y=val_loss, win='val_loss', opts=dict(title='val_loss'))
    viz.line(Y=train_acc, win='train_acc', opts=dict(title='train_acc'))
    viz.line(Y=val_acc, win='val_acc', opts=dict(title='val_acc'))
    
    


        Original article: https://www.haomeiwen.com/subject/emitdqtx.html