感谢伯禹学习平台,本次学习将记录记录如何使用Pytorch高效实现网络,熟练掌握Pytorch的基础知识,记录不包含理论知识的细节展开。
一:线性回归-Pytorch实现
#导入相关的包
import torch
from torch import nn
#pytorch网络需要继承 nn.Module,定义结果,完成前向forward
class LinearNet(nn.Module):
def __init__(self, n_feature):
#继承父类
super(LinearNet, self).__init__()
# 线性层,n_feature表示输入X的特征数目,1表示输出Y的维度只有1
self.linear = nn.Linear(n_feature, 1)
#也可以如下定义,init后面加入o_feature实现输出维度自定义
#self.linear = nn.Linear(n_feature, o_feature)
def forward(self, x):
y = self.linear(x)
return y
# 调用网络只需要新建一个该网络的对象即可
net = LinearNet(n_feature)
# 损失函数和优化算子,源代码里有很多,可自行查看
loss = nn.loss = nn.MSELoss()
import torch.optim as optim
# SGD最常用,注意net.parameters()可以有很多种写法,分层学习率也可以在此设置,lr表示学习率
optimizer = optim.SGD(net.parameters(), lr=0.03)
训练网络的基本步骤
num_epochs =3
for epoch in range(1, num_epochs + 1):
#data_iter参考pytorch定义的Dataloader
for X, y in data_iter:
output = net(X)
l = loss(output, y.view(-1, 1))
# 以下三步比较主要,zero_grad防止梯度累计,清空上个batch的梯度影响
optimizer.zero_grad()
# 反向传播
l.backward()
optimizer.step()
print('epoch %d, loss: %f' % (epoch, l.item()))
二:Softmax-Pytorch实现
Softmax只不过是定义了最后的输出加权求和为1,通常是将回归转到一个多分类问题,其本身也有很多限制。因此其相对于LR网络本身的区别在于最后一层的输出,而Pytorch则更加简便实现。
#原有的线性网络可以不做太多改变
class LinearNet(nn.Module):
def __init__(self, n_feature,o_feature):
super(LinearNet, self).__init__()
# Pytorch通常直到线性层,因为Softmax可以在损失函数里计算
self.linear = nn.Linear(n_feature, o_feature)
def forward(self, x):
# 如果输入x的维度为 [batch_size,28,28], 则需要对x重新reshape
y = self.linear(x.view(x.shape[0], -1))
return y
定义损失函数
# CrossEntropyLoss 真实标签不使用one-hot编码
loss = nn.CrossEntropyLoss()
三:MLP(多层感知机)-Pytorch实现
MLP的定义不同的的地方略有差异。我理解其本身实在LR的基础上多引入层数,在层后面可以连接不同的激活函数,以增强网络的拟合能力。虽然网络很简单有时候直接使用Sequential更快,但是Pytorch用多了就会发现类还是比较顺手(个人观点)。
class MLPNet(nn.Module):
def __init__(self, num_inputs, num_outputs):
super(LinearNet, self).__init__()
self.linear1 = nn.Linear(num_inputs, 256)
#relu可以重复调用,或者直接使用F.relu()
self.relu = nn.ReLU()
self.linear2 = nn.Linear(256, num_outputs)
def forward(self, x):
x = self.linear1(x.view(x.shape[0], -1))
x = self.relu(x)
x = self.linear2(x)
x = self.relu(x)
return x
net = LinearNet(num_inputs,num_outputs)
除了Relu还有很多别的激活函数,sigmoid,tanh等等。
网友评论