美文网首页AI成长社
pytorch做二分类,多分类以及回归任务

pytorch做二分类,多分类以及回归任务

作者: 文一休 | 来源:发表于2019-09-26 14:26 被阅读0次

【lightgbm/xgboost/nn代码整理四】pytorch做二分类,多分类以及回归任务

1.简介

本不打算整理pytorch代码,因为在数据挖掘类比赛中没有用过它,做图像相关任务时用pytorch比较多。有个小哥提到让整理一下,就花了几天时间整理了一份,没有很仔细调试过,有问题请读者指出。下面将从数据处理、网络搭建和模型训练三个部分介绍。如果只是想要阅读代码,可直接移步到尾部链接。

2. 数据处理

参考上一节的数据处理

3.模型

pytorch 定义的mlp代码如下:

class MLP(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output, dropout=0.5):
        super(MLP, self).__init__()
        self.dropout = torch.nn.Dropout(dropout)

        self.hidden_1 = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.bn1 = torch.nn.BatchNorm1d(n_hidden)

        self.hidden_2 = torch.nn.Linear(n_hidden, n_hidden//2)
        self.bn2 = torch.nn.BatchNorm1d(n_hidden//2)

        self.hidden_3 = torch.nn.Linear(n_hidden//2, n_hidden//4)  # hidden layer
        self.bn3 = torch.nn.BatchNorm1d(n_hidden//4)

        self.hidden_4 = torch.nn.Linear(n_hidden // 4, n_hidden // 8)  # hidden layer
        self.bn4 = torch.nn.BatchNorm1d(n_hidden // 8)

        self.out = torch.nn.Linear(n_hidden//8, n_output)  # output layer

    def forward(self, x):
        x = F.relu(self.hidden_1(x))  # activation function for hidden layer
        x = self.dropout(self.bn1(x))
        x = F.relu(self.hidden_2(x))  # activation function for hidden layer
        x = self.dropout(self.bn2(x))
        x = F.relu(self.hidden_3(x))  # activation function for hidden layer
        x = self.dropout(self.bn3(x))
        x = F.relu(self.hidden_4(x))  # activation function for hidden layer
        x = self.dropout(self.bn4(x))
        x = self.out(x)
        return x

定义的网路结构和上一节keras中定义的一样,同样也添加了dropout层和bn层。不同之处这个网络最终的输出都是线性输出。

训练和预测

4.1 数据加载

pytorch是以tensor的形式加载数据,需要将数据转为tenser格式,如果有gpu处理器,并且安装的也是gpu版本的pytorch,就可以使用gpu加速处理,通过DataLoader来加载数据,代码如下。

x_test = np.array(test_X)
x_test = torch.tensor(x_test, dtype=torch.float)
if torch.cuda.is_available():
    x_test = x_test.cuda()
test = TensorDataset(x_test)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

4.2 训练

model = MLP(x_train.shape[1], 512, classes, dropout=0.3)
if torch.cuda.is_available():
    model = model.cuda()
    
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()  #多分类
#loss_fn = torch.nn.BCEWithLogitsLoss() #二分类
#loss_fn = torch.nn.L1Loss()   #回归

y_pred = model(x_batch)
loss = loss_fn(y_pred, y_batch)
optimizer.zero_grad()       
loss.backward()             
optimizer.step()            

定义完网络后,如果存在GPU,则需要将model也添加上gpu。优化函数同keras一样,都含有adam,sgd等。损失函数针对不同问题有所不同,在代码中已有标注,上面列出的分类任务都采用的是交叉熵损失函数,集成了最后一层的激活函数,如多分类的CrossEntropyLoss,它已经集成了softmax函数,且不需要对类别类别做onehot处理,直接输入int值即可。

  • optimizer.zero_grad():是为下一次训练清除梯度值
  • loss.backward()是反向传播,计算每个参数的梯度值
  • optimizer.step():是更新参数权重值包括,weights和biases

4.3 预测

在预测中eval()函数会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。代码如下

model.eval()
y_pred = model(x_batch)
test_preds_fold[i * batch_size:(i + 1) * batch_size] = y_pred.cpu().numpy()

由于计算的结果是tensor,需要转为numpy。

最终的结果转化同keras一样,如二分类需要设定阈值。

代码地址:data_mining_models

写在最后

关注公号:

AI成长社
ML与DL的成长圣地。
知乎专栏:ML与DL成长之路

相关文章

  • pytorch做二分类,多分类以及回归任务

    【lightgbm/xgboost/nn代码整理四】pytorch做二分类,多分类以及回归任务 1.简介 本不打算...

  • xgboost做二分类,多分类以及回归任务

    【lightgbm/xgboost/nn代码整理二】xgboost做二分类,多分类以及回归任务 1.简介 该部分是...

  • keras做二分类,多分类以及回归任务

    【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回归任务 1.简介 该部分是比较...

  • 逻辑回归

    逻辑回归是一个分类算法,利用回归来做分类.它可以处理二元分类以及多元分类,逻辑回归与线性回归不同主要体现在以下两点...

  • 逻辑回归模型(LR)

    1.模型概念 逻辑回归模型是一种分类模型,它可以处理二院分类以及多分类的任务。我们知道,线性回归的模型是求...

  • pytorch中的损失函数

    1. 多标签分类损失函数 pytorch中能计算多标签分类任务loss的方法有好几个。binary_cross_e...

  • 线性回归

    什么叫回归,什么叫分类? 对连续型变量做预测叫回归,对离散型变量做预测叫分类 线性回归的主要任务是什么? 线性回归...

  • 《机器学习》第3章

    回归和分类的区别 对连续型变量做预测叫回归,对离散型变量做预测叫分类(好瓜坏瓜) 线性回归 1、 线性回归的任务 ...

  • 《机器学习》第3章

    回归和分类的区别 对连续型变量做预测叫回归,对离散型变量做预测叫分类(好瓜坏瓜) 线性回归 1、 线性回归的任务 ...

  • 神经网络基本介绍和建立流程

    神经网络的作用 神经网络可以解决分类任务和回归任务,通常我们解决的任务也就是分类和回归。比如分类任务中有 关于糖尿...

网友评论

    本文标题:pytorch做二分类,多分类以及回归任务

    本文链接:https://www.haomeiwen.com/subject/cpbgectx.html