美文网首页
4. pytorch-简单分类

4. pytorch-简单分类

作者: FantDing | 来源:发表于2018-07-01 08:52 被阅读0次
    import torch
    import matplotlib.pyplot as plt
    import torch.nn.functional as F
    
    
    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.h1 = torch.nn.Linear(2, 10)
            # self.a1=torch.nn.ReLU()
            self.h2 = torch.nn.Linear(10, 2)
    
        def forward(self, x):
            # x=self.a1(self.h1(x)) # 使用这种写法,在print(net)时,会显示relu层
            x = F.relu(self.h1(x))  # 这种写法不会显示relu层,对于无状态的函数,推荐使用F形式(简单)
            x = self.h2(x)
            return x
    
    
    if __name__ == "__main__":
        # test()
        # 1. 数据准备
        x0 = torch.normal(torch.ones(100, 2) * 1, 1)
        y0 = torch.zeros(100, 1)
        x1 = torch.normal(torch.ones(100, 2) * -2, 1)
        y1 = torch.ones(100, 1)
    
        x = torch.cat((x0, x1), dim=0)
        y = torch.cat((y0, y1), dim=0).long().squeeze()
        # 可视化数据
        # plt.scatter(x0.numpy()[:, 0], x0.numpy()[:, 1], c="red", label="negtive")
        # plt.scatter(x1.numpy()[:, 0], x1.numpy()[:, 1], c="green", label="positive")
        # plt.legend()
        # plt.show()
    
        # 2. 定义网络
        net = Net()
        # 3. 训练
        optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
        loss_F = torch.nn.CrossEntropyLoss()
        for iter in range(100):
            pred = net(x)
            # arg1: 二维的原始输出(没有加softmax)
            # arg2: 一维的batch_size的真实cls label(不是二维one hot编码)
            loss = loss_F(pred, y)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            plt.ion()
            # 计算精度
            if iter % 2 == 0:
                print("*" * 8, "iter:", iter, "*" * 8)
                loss_str = "loss: {:.4f}".format(loss.data.numpy())
                print(loss_str)
                probability = F.softmax(pred, dim=1)
                pred_cls = torch.argmax(probability, dim=1)
                equal = (pred_cls == y)  # 返回的torch元素为0或1,不是bool类型
                accuracy = torch.sum(equal).data.numpy() / 200
                print("accuracy:", accuracy)
    
                # 画图
                plt.cla()
                plt.scatter(x.numpy()[:, 0], x.numpy()[:, 1], c=pred_cls)
                plt.text(1, -3.5, "accuracy:{}".format(accuracy))
                plt.pause(0.2)
    
        plt.ioff()
        plt.show()
    
    image.png

    相关文章

      网友评论

          本文标题:4. pytorch-简单分类

          本文链接:https://www.haomeiwen.com/subject/khkcuftx.html