美文网首页
ArcFaceLoss和CenterLoss的代码实现

ArcFaceLoss和CenterLoss的代码实现

作者: 星光下的胖子 | 来源:发表于2021-05-13 16:42 被阅读0次

    手动实现ArcFaceLoss和CenterLoss,并用来训练MNIST数据。

    导入相关库
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms as T
    from torch.utils.data import DataLoader
    import itertools
    import matplotlib.pyplot as plt
    # 查看时间和进度
    from tqdm import tqdm
    import time
    
    实现ArcFaceNet和CenterLossNet
    class ArcFaceNet(nn.Module):
        """Additive-angular-margin (ArcFace-style) classification head.

        Holds one learnable weight vector per class and maps a batch of feature
        vectors to per-class log-probabilities, so the result can be passed
        directly to ``nn.NLLLoss``.

        Args:
            cls_num: number of classes (columns of the weight matrix).
            feature_dim: dimensionality of the incoming feature vectors.
        """

        def __init__(self, cls_num=10, feature_dim=2):
            super(ArcFaceNet, self).__init__()
            # One column per class; the columns are L2-normalised in forward().
            self.w = nn.Parameter(torch.randn(feature_dim, cls_num))

        def forward(self, features, m=1, s=10):
            """Return margin-penalised log-probabilities of shape (batch, cls_num).

            Entry ``[i, j]`` is ``log(e^{s*cos(theta_ij + m)} / (sum_{k != j}
            e^{s*cos(theta_ik)} + e^{s*cos(theta_ij + m)}))``, i.e. the margin is
            applied to every column; a loss such as ``nn.NLLLoss`` then picks the
            column of the true class.

            Args:
                features: tensor of shape (batch, feature_dim).
                m: additive angular margin, in radians.
                s: scale applied to the cosines before exponentiation.
            """
            # Normalise features (row-wise) and class weights (column-wise) so
            # their product is the cosine of the angle between them.
            _features = nn.functional.normalize(features, dim=1)
            _w = nn.functional.normalize(self.w, dim=0)

            # Keep the cosine strictly inside (-1, 1): acos has an unbounded
            # derivative at +/-1. Clamping preserves the true angle, whereas
            # rescaling the cosine would distort it and shrink the scale s.
            eps = 1e-6
            cosine = torch.matmul(_features, _w).clamp(-1 + eps, 1 - eps)
            theta = torch.acos(cosine)

            # NOTE: for theta + m > pi the cosine stops decreasing, so very
            # large margins no longer penalise monotonically.
            margin_logits = s * torch.cos(theta + m)
            plain_exp = torch.exp(s * cosine)

            # Denominator: the plain terms of all *other* classes plus the
            # margin-penalised term of the class itself.
            denominator = torch.sum(plain_exp, dim=1, keepdim=True) - plain_exp + torch.exp(margin_logits)

            # log(exp(a) / d) == a - log(d); avoids an exp/log round trip.
            return margin_logits - torch.log(denominator)
        
    class CenterLossNet(nn.Module):
        """Center loss: pulls each normalised feature towards its class centre.

        Keeps one learnable centre per class and returns half the squared
        Euclidean distance between the L2-normalised features and the centres
        selected by the labels.

        Args:
            cls_num: number of class centres.
            feature_dim: dimensionality of every centre / feature vector.
        """

        def __init__(self, cls_num=10, feature_dim=2):
            super().__init__()
            self.centers = nn.Parameter(torch.randn(cls_num, feature_dim))

        def forward(self, features, labels, reduction='mean'):
            """Compute the center loss for one batch.

            Args:
                features: tensor of shape (batch, feature_dim).
                labels: 1-D tensor of class indices, one per feature row.
                reduction: ``'sum'`` returns the summed loss, ``'mean'`` (the
                    default) divides that sum by ``len(features)``.

            Raises:
                ValueError: if ``reduction`` is neither ``'sum'`` nor ``'mean'``.
            """
            # L2-normalise every feature vector (normalize defaults to dim=1).
            unit_features = nn.functional.normalize(features)
            # Look up the centre that belongs to each sample's label.
            own_centers = torch.index_select(self.centers, 0, labels.long())

            # Formulation follows "A Discriminative Feature Learning Approach
            # for Deep Face Recognition": half the summed squared distance.
            half_sq_dist = (unit_features - own_centers).pow(2).sum() / 2

            if reduction == 'sum':
                return half_sq_dist
            if reduction == 'mean':
                return half_sq_dist / len(features)
            raise ValueError("ValueError: {0} is not a valid value for reduction".format(reduction))
    
    定义LeNet模型
    class LeNet(nn.Module):
        """Small CNN that returns both an embedding and class log-probabilities.

        Args:
            feature_dim: width of the embedding returned as ``features``
                (default 32). Use 2 to get an embedding that can be plotted
                directly.
            cls_num: number of output classes (default 10).
        """

        def __init__(self, feature_dim=32, cls_num=10):
            super().__init__()
            # Two 3x3 convolutions; the second has stride 2, so a 1x28x28 input
            # leaves this stage as 32 x 14 x 14 before being flattened.
            self.conv = nn.Sequential(
                nn.Conv2d(1, 64, 3, padding=1),
                nn.PReLU(),
                nn.BatchNorm2d(64),
                nn.Conv2d(64, 32, 3, stride=2, padding=1),
                nn.PReLU(),
                nn.BatchNorm2d(32),
                nn.Flatten(),
            )
            # Fully connected stack that ends in the embedding (no activation
            # after the last layer).
            self.linear = nn.Sequential(
                nn.Linear(32 * 14 * 14, 512),
                nn.PReLU(),
                nn.BatchNorm1d(512),
                nn.Linear(512, 256),
                nn.PReLU(),
                nn.BatchNorm1d(256),
                nn.Linear(256, 64),
                nn.PReLU(),
                nn.BatchNorm1d(64),
                nn.Linear(64, feature_dim),
            )
            # Classification head. LogSoftmax output pairs with nn.NLLLoss to
            # give the cross-entropy loss.
            self.out_layer = nn.Sequential(
                nn.Linear(feature_dim, cls_num),
                nn.LogSoftmax(dim=1),
            )

        def forward(self, x):
            """Return ``(features, out)`` for a batch of images.

            ``features`` has shape (batch, feature_dim); ``out`` holds the
            log-probabilities of shape (batch, cls_num).
            """
            features = self.linear(self.conv(x))
            out = self.out_layer(features)
            return features, out
    
    模型训练

    两种损失计算方式:

    • CrossEntropyLoss+CenterLoss
    • ArcFaceLoss+CenterLoss

    超参数都是最初随意设定的,只跑了一遍,精度可达到99.29%。读者可以自行调整超参数,精度还能更高。训练代码如下:

    # Feature-vector visualisation
    def visualize(features, labels, loss, epoch, tag="ce+cl"):
        """Scatter-plot the first two feature dimensions, one colour per class.

        The figure is saved to ``<tag>/image/epoch<epoch>.jpg`` and also shown
        on the current (interactive) matplotlib figure.

        Args:
            features: 2-D array indexed as ``features[mask, 0]`` /
                ``features[mask, 1]``; only columns 0 and 1 are drawn.
            labels: 1-D array of class ids (0-9) aligned with ``features``.
            loss: value shown in the title.
            epoch: integer epoch number, used in the title and file name.
            tag: experiment name used for the title prefix and the output
                directory. Defaults to ``"ce+cl"``.
        """
        import os  # local import keeps this helper self-contained

        # One colour per class id 0..9.
        colors = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff', '#ff00ff', '#990000', '#999900', '#009900',
                  '#009999']

        # savefig does not create missing directories, so make sure it exists.
        image_dir = os.path.join(tag, "image")
        os.makedirs(image_dir, exist_ok=True)

        plt.clf()  # start from an empty canvas
        # Draw every point, coloured by its label.
        for i, color in enumerate(colors):
            mask = labels == i
            plt.plot(features[mask, 0], features[mask, 1], ".", c=color, label=i)
        plt.legend(loc="upper right")
        plt.title(f"{tag}: epoch={epoch}, loss={loss}")
        plt.savefig(os.path.join(image_dir, "epoch%d.jpg" % epoch))
        plt.draw()  # refresh the interactive window
        plt.pause(0.001)
    
    # 1. Load the datasets
    # Pre-processing: convert to a tensor, then normalise the single channel.
    # NOTE(review): 0.4914 / 0.2023 look like CIFAR-10 channel statistics rather
    # than MNIST's - confirm they are intended.
    transform_op = T.Compose([T.ToTensor(), T.Normalize([0.4914], [0.2023])])
    # Training split first, then the held-out split; nothing is downloaded.
    train_dataset, val_dataset = (
        datasets.MNIST("../code/data", train=is_train, transform=transform_op, download=False)
        for is_train in (True, False)
    )
    # Only the training batches are shuffled.
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    # Training device: first GPU when CUDA is available, otherwise the CPU
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Hyper-parameters
    lr = 1e-3  # initial learning rate
    epochs = 20
    # epoch index -> learning rate applied from that epoch on
    lr_schedule = {5: 1e-3, 10: 1e-4, 15: 1e-5}
    # Weight of the classification loss (cross-entropy or ArcFace);
    # the center loss is weighted by 1 - alpha.
    alpha = 0.95
    

    1.CrossEntropyLoss+CenterLoss

    # 2. Build the models
    import os  # used below to create the output directories

    cls_num, feature_dim = 10, 32  # 10-way classification
    # For a 2-D embedding that can be plotted directly, use feature_dim = 2
    # (the last layers of LeNet have to be changed to match).
    net = LeNet().to(device)
    centerloss_net = CenterLossNet(cls_num, feature_dim).to(device)
    # 3. Loss: NLLLoss on LeNet's LogSoftmax output, i.e. cross-entropy
    loss_func = nn.NLLLoss()
    # 4. Optimiser: updates the network and the class centres together
    optimizer = optim.Adam(itertools.chain(net.parameters(), centerloss_net.parameters()), lr)

    # torch.save / savefig do not create missing directories; without this the
    # run would fail only after a whole epoch has been trained.
    os.makedirs("ce+cl/checkpoint", exist_ok=True)
    os.makedirs("ce+cl/image", exist_ok=True)  # visualize() saves its figures here

    # 5. Training
    plt.ion()
    for epoch in range(epochs):
        start = time.time()
        # Learning-rate schedule. The module-level `lr` is deliberately left
        # untouched so later code still sees the initial value.
        if epoch in lr_schedule:
            for group in optimizer.param_groups:
                group["lr"] = lr_schedule[epoch]

        # 1) Training set
        net.train()
        features_loader, labels_loader = [], []  # per-batch features / labels for the plot
        train_loss = 0.
        for images, targets in tqdm(train_dataloader):
            images, targets = images.to(device), targets.to(device)

            # Variant 1: CrossEntropyLoss + CenterLoss
            features, out = net(images)
            ce_loss = loss_func(out, targets)
            center_loss = centerloss_net(features, targets)
            loss = alpha * ce_loss + (1 - alpha) * center_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Store detached CPU copies: keeping the raw tensors would hold on
            # to every batch's autograd graph (and device memory) all epoch.
            features_loader.append(features.detach().cpu())
            labels_loader.append(targets.cpu())
        # Mean of the per-batch losses
        train_loss /= len(train_dataloader)

        # 2) Validation set
        net.eval()
        val_loss, correct = 0., 0
        with torch.no_grad():  # no gradients needed, saves memory
            for images, targets in tqdm(val_dataloader):
                images, targets = images.to(device), targets.to(device)

                # Variant 1: CrossEntropyLoss + CenterLoss
                features, out = net(images)
                ce_loss = loss_func(out, targets)
                center_loss = centerloss_net(features, targets)
                loss = alpha * ce_loss + (1 - alpha) * center_loss

                val_loss += loss.item()
                # Tensor-side reduction; the built-in sum() would iterate over
                # the tensor element by element.
                correct += (out.argmax(1) == targets).sum().item()
            # Mean of the per-batch losses
            val_loss /= len(val_dataloader)
        accuracy = correct / len(val_dataset)

        # Report loss and accuracy
        print(f"Epoch: {epoch}/{epochs}, Train_loss: {train_loss:.5f}, Val_loss: {val_loss:.5f}, Accuracy: {accuracy}")
        # Save the parameters (overwritten every epoch)
        torch.save(net.state_dict(), "ce+cl/checkpoint/net.pt")
        torch.save(centerloss_net.state_dict(), "ce+cl/checkpoint/centerloss_net.pt")
        # Plot this epoch's training features
        features = torch.cat(features_loader, dim=0)
        labels = torch.cat(labels_loader, dim=0)
        visualize(features.numpy(), labels.numpy(), train_loss, epoch)
        # Timing
        end = time.time()
        print(f"第{epoch}次轮询,共耗时{int(end - start)}秒")
        time.sleep(0.01)
    plt.ioff()
    

    2.ArcFaceLoss+CenterLoss

    # 2. Build the models
    import os  # used below to create the output directories

    cls_num, feature_dim = 10, 32  # 10-way classification
    # For a 2-D embedding that can be plotted directly, use feature_dim = 2
    # (the last layers of LeNet have to be changed to match).
    net = LeNet().to(device)
    arcface_net = ArcFaceNet(cls_num, feature_dim).to(device)
    centerloss_net = CenterLossNet(cls_num, feature_dim).to(device)
    # 3. Loss: NLLLoss applied to the log-probabilities from arcface_net
    loss_func = nn.NLLLoss()
    # 4. Optimiser: updates the network, the ArcFace weights and the centres.
    # The module-level `lr` may have been rebound by an earlier training loop
    # in this file, so this run starts from an explicit initial rate.
    initial_lr = 1e-3
    optimizer = optim.Adam(
        itertools.chain(net.parameters(), arcface_net.parameters(), centerloss_net.parameters()),
        initial_lr,
    )

    # torch.save / savefig do not create missing directories; without this the
    # run would fail only after a whole epoch has been trained.
    os.makedirs("arcface+cl/checkpoint", exist_ok=True)
    # NOTE: visualize() titles its figures "ce+cl" and saves them under
    # ce+cl/image by default, so that directory is what has to exist here.
    os.makedirs("ce+cl/image", exist_ok=True)

    # 5. Training
    plt.ion()
    for epoch in range(epochs):
        start = time.time()
        # Learning-rate schedule
        if epoch in lr_schedule:
            for group in optimizer.param_groups:
                group["lr"] = lr_schedule[epoch]

        # 1) Training set
        net.train()
        features_loader, labels_loader = [], []  # per-batch features / labels for the plot
        train_loss = 0.
        for images, targets in tqdm(train_dataloader):
            images, targets = images.to(device), targets.to(device)

            # Variant 2: ArcFaceLoss + CenterLoss (LeNet's own head is unused)
            features, _ = net(images)
            out = arcface_net(features)
            arcface_loss = loss_func(out, targets)
            center_loss = centerloss_net(features, targets)
            loss = alpha * arcface_loss + (1 - alpha) * center_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Store detached CPU copies: keeping the raw tensors would hold on
            # to every batch's autograd graph (and device memory) all epoch.
            features_loader.append(features.detach().cpu())
            labels_loader.append(targets.cpu())
        # Mean of the per-batch losses
        train_loss /= len(train_dataloader)

        # 2) Validation set
        net.eval()
        val_loss, correct = 0., 0
        with torch.no_grad():  # no gradients needed, saves memory
            for images, targets in tqdm(val_dataloader):
                images, targets = images.to(device), targets.to(device)

                # Variant 2: ArcFaceLoss + CenterLoss
                features, _ = net(images)
                out = arcface_net(features)
                arcface_loss = loss_func(out, targets)
                center_loss = centerloss_net(features, targets)
                loss = alpha * arcface_loss + (1 - alpha) * center_loss

                val_loss += loss.item()
                # Tensor-side reduction; the built-in sum() would iterate over
                # the tensor element by element.
                correct += (out.argmax(1) == targets).sum().item()
            # Mean of the per-batch losses
            val_loss /= len(val_dataloader)
        accuracy = correct / len(val_dataset)

        # Report loss and accuracy. The first line shows the weighted and raw
        # loss terms of the last validation batch only.
        print(alpha * arcface_loss, (1 - alpha) * center_loss, arcface_loss, center_loss)
        print(f"Epoch: {epoch}/{epochs}, Train_loss: {train_loss:.5f}, Val_loss: {val_loss:.5f}, Accuracy: {accuracy}")
        # Save the parameters (overwritten every epoch)
        torch.save(net.state_dict(), "arcface+cl/checkpoint/net.pt")
        torch.save(centerloss_net.state_dict(), "arcface+cl/checkpoint/centerloss_net.pt")
        torch.save(arcface_net.state_dict(), "arcface+cl/checkpoint/arcface_net.pt")
        # Plot this epoch's training features. visualize() takes
        # (features, labels, loss, epoch); the previous call passed six
        # arguments in a different order and raised TypeError.
        features = torch.cat(features_loader, dim=0)
        labels = torch.cat(labels_loader, dim=0)
        visualize(features.numpy(), labels.numpy(), train_loss, epoch)
        # Timing
        end = time.time()
        print(f"第{epoch}次轮询,共耗时{int(end - start)}秒")
        time.sleep(0.01)
    plt.ioff()
    

    相关文章

      网友评论

          本文标题:ArcFaceLoss和CenterLoss的代码实现

          本文链接:https://www.haomeiwen.com/subject/cgbadltx.html