美文网首页
3.3.分类问题

3.3.分类问题

作者: BlueFishMan | 来源:发表于2018-07-11 09:52 被阅读0次
    # Logistic回归
    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    
    # Load the dataset: every non-blank row of iris.csv is parsed as four
    # numeric feature columns followed by one numeric label column.
    # NOTE(review): nn.BCELoss expects targets in [0, 1]; this assumes the
    # fifth column holds 0/1 class labels -- confirm against iris.csv.
    with open('iris.csv','r') as f:
        # Strip line endings and skip blank rows (e.g. a trailing newline at
        # the end of the file); a blank row would otherwise reach float('')
        # and raise ValueError.
        data_list=[line.strip().split(',') for line in f if line.strip()]

    # Parsing and tensor construction do not need the file handle, so they run
    # after the file has been closed.
    x_data=[(float(row[0]),float(row[1]),float(row[2]),float(row[3])) for row in data_list]
    y_data=[float(row[4]) for row in data_list]
    x_data=torch.Tensor(x_data)
    # Turn the flat label vector into a column so it has one entry per row.
    y_data=torch.Tensor(y_data).unsqueeze(1)

    # Variable is kept so the script still runs on PyTorch releases that require
    # it; on newer releases it simply returns the tensor.
    x=Variable(x_data)
    y=Variable(y_data)
    if torch.cuda.is_available():
        # Move features and labels to the GPU when one is present.
        x=x.cuda()
        y=y.cuda()
    
    # Define the logistic regression model (the binary-classification loss and
    # the optimizer are created further below).
    class LogisticRegression(nn.Module):
        """Logistic regression: one linear layer followed by a sigmoid."""

        def __init__(self):
            super(LogisticRegression, self).__init__()
            # Linear map from the 4 input features to a single output.
            self.lr = nn.Linear(4, 1)
            # Sigmoid applied to the linear output.
            self.sm = nn.Sigmoid()

        def forward(self, x):
            """Return the sigmoid of the linear projection of ``x``."""
            return self.sm(self.lr(x))
    # Build the model, then move it to the GPU when CUDA is available.
    logistic_model=LogisticRegression()
    if torch.cuda.is_available():
        logistic_model=logistic_model.cuda()

    # Binary cross-entropy: the loss function for two-class classification.
    criterion=nn.BCELoss()
    # SGD with momentum over all of the model's parameters.
    optimizer=torch.optim.SGD(
        logistic_model.parameters(),
        lr=1e-3,
        momentum=0.9,
    )
    
    # Train the model: full-batch gradient descent for 50000 epochs, reporting
    # loss and training accuracy every 1000 epochs.
    for epoch in range(50000):
        # Forward pass over the whole dataset.
        out=logistic_model(x)
        loss=criterion(out,y)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch+1)%1000==0:
            # `out` and `loss` were computed before optimizer.step(), so the
            # values reported here are the pre-update metrics for this epoch.
            # .item() extracts the Python number from a 0-dim tensor; the old
            # `.data[0]` indexing raises IndexError on PyTorch >= 0.5.
            print_loss=loss.item()
            # Threshold the outputs: >= 0.5 counts as class 1, otherwise class 0.
            mask=out.ge(0.5).float()
            correct=(mask==y).sum()
            acc=correct.item()/x.size(0)
            print('*'*10)
            print('epoch{}'.format(epoch+1))
            print('loss is {:.4f}'.format(print_loss))
            print('acc is {:.4f}'.format(acc))
    
    **********
    epoch1000
    loss is 0.0667
    acc is 1.0000
    **********
    epoch2000
    loss is 0.0349
    acc is 1.0000
    **********
    epoch3000
    loss is 0.0239
    acc is 1.0000
    **********
    epoch4000
    loss is 0.0183
    acc is 1.0000
    **********
    epoch5000
    loss is 0.0149
    acc is 1.0000
    **********
    epoch6000
    loss is 0.0125
    acc is 1.0000
    **********
    epoch7000
    loss is 0.0109
    acc is 1.0000
    **********
    epoch8000
    loss is 0.0096
    acc is 1.0000
    **********
    epoch9000
    loss is 0.0086
    acc is 1.0000
    **********
    epoch10000
    loss is 0.0078
    acc is 1.0000
    **********
    epoch11000
    loss is 0.0072
    acc is 1.0000
    **********
    epoch12000
    loss is 0.0066
    acc is 1.0000
    **********
    epoch13000
    loss is 0.0062
    acc is 1.0000
    **********
    epoch14000
    loss is 0.0058
    acc is 1.0000
    **********
    epoch15000
    loss is 0.0054
    acc is 1.0000
    **********
    epoch16000
    loss is 0.0051
    acc is 1.0000
    **********
    epoch17000
    loss is 0.0048
    acc is 1.0000
    **********
    epoch18000
    loss is 0.0046
    acc is 1.0000
    **********
    epoch19000
    loss is 0.0044
    acc is 1.0000
    **********
    epoch20000
    loss is 0.0042
    acc is 1.0000
    **********
    epoch21000
    loss is 0.0040
    acc is 1.0000
    **********
    epoch22000
    loss is 0.0038
    acc is 1.0000
    **********
    epoch23000
    loss is 0.0037
    acc is 1.0000
    **********
    epoch24000
    loss is 0.0035
    acc is 1.0000
    **********
    epoch25000
    loss is 0.0034
    acc is 1.0000
    **********
    epoch26000
    loss is 0.0033
    acc is 1.0000
    **********
    epoch27000
    loss is 0.0032
    acc is 1.0000
    **********
    epoch28000
    loss is 0.0031
    acc is 1.0000
    **********
    epoch29000
    loss is 0.0030
    acc is 1.0000
    **********
    epoch30000
    loss is 0.0029
    acc is 1.0000
    **********
    epoch31000
    loss is 0.0028
    acc is 1.0000
    **********
    epoch32000
    loss is 0.0027
    acc is 1.0000
    **********
    epoch33000
    loss is 0.0026
    acc is 1.0000
    **********
    epoch34000
    loss is 0.0026
    acc is 1.0000
    **********
    epoch35000
    loss is 0.0025
    acc is 1.0000
    **********
    epoch36000
    loss is 0.0024
    acc is 1.0000
    **********
    epoch37000
    loss is 0.0024
    acc is 1.0000
    **********
    epoch38000
    loss is 0.0023
    acc is 1.0000
    **********
    epoch39000
    loss is 0.0023
    acc is 1.0000
    **********
    epoch40000
    loss is 0.0022
    acc is 1.0000
    **********
    epoch41000
    loss is 0.0022
    acc is 1.0000
    **********
    epoch42000
    loss is 0.0021
    acc is 1.0000
    **********
    epoch43000
    loss is 0.0021
    acc is 1.0000
    **********
    epoch44000
    loss is 0.0020
    acc is 1.0000
    **********
    epoch45000
    loss is 0.0020
    acc is 1.0000
    **********
    epoch46000
    loss is 0.0020
    acc is 1.0000
    **********
    epoch47000
    loss is 0.0019
    acc is 1.0000
    **********
    epoch48000
    loss is 0.0019
    acc is 1.0000
    **********
    epoch49000
    loss is 0.0018
    acc is 1.0000
    **********
    epoch50000
    loss is 0.0018
    acc is 1.0000

    相关文章

      网友评论

          本文标题:3.3.分类问题

          本文链接:https://www.haomeiwen.com/subject/hmwcpftx.html