# 李沐《动手学深度学习》第三章 (Mu Li, "Dive into Deep Learning", Chapter 3)
from typing import Iterable
import torch
import torchvision
def load_data_fashion_mnist(batch_size, resize=None, root='./FashionMNIST', num_workers=4):
    """Download Fashion-MNIST (if needed) and wrap both splits in data loaders.

    Args:
        batch_size: Number of samples per mini-batch, used for both loaders.
        resize: Optional size forwarded to ``torchvision.transforms.Resize``.
            When falsy, no resize step is added.
        root: Directory the dataset is downloaded to / read from.
        num_workers: Number of worker processes given to each ``DataLoader``.
            Defaults to 4, the value that used to be hard-coded; pass 0 to
            load batches in the calling process.

    Returns:
        A ``(train_iter, test_iter)`` tuple of ``DataLoader`` objects. The
        training loader shuffles its samples, the test loader does not.
    """
    # Resize (when requested) is applied before the conversion to tensors.
    steps = []
    if resize:
        steps.append(torchvision.transforms.Resize(size=resize))
    steps.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)

    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter
# Load the data
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

# Define and initialise the model
# Input and output sizes of the linear layer built by ``Net``.
num_inputs, num_output = 784, 10
class FlattenLayer(torch.nn.Module):
    """
    Flatten every sample of a batch into a single vector.

    The first dimension is kept as the batch dimension; all remaining
    dimensions are merged into one.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Flatten ``x``, e.g. from shape (B, 28, 28) to shape (B, 28*28).

        ``reshape`` is used rather than ``view`` so that non-contiguous
        inputs (for example the result of ``transpose`` or ``permute``) are
        accepted too; ``view`` raises a ``RuntimeError`` for those.

        :param x: input tensor whose first dimension is the batch dimension
        :return: tensor of shape ``(x.shape[0], -1)``
        """
        return x.reshape(x.shape[0], -1)
class Net(torch.nn.Module):
    """
    Linear model: flatten the input, then apply one fully connected layer.
    """

    def __init__(self):
        super().__init__()
        self.flatten = FlattenLayer()
        # Layer sizes are taken from the module-level ``num_inputs`` and
        # ``num_output``.
        self.linear = torch.nn.Linear(num_inputs, num_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the forward pass.

        :param x: input batch; it is flattened before the linear layer
        :return: output of the linear layer
        """
        return self.linear(self.flatten(x))
net = Net()
print("net ...")

# Initialise the parameters directly through the layer's attributes.
torch.nn.init.normal_(net.linear.weight, mean=0, std=0.01)
torch.nn.init.constant_(net.linear.bias, val=0)

# The same initialisation can also be done by walking the named parameters.
for name, param in net.named_parameters():
    if name.endswith(".weight"):
        torch.nn.init.normal_(param, mean=0., std=0.01)
        continue
    if name.endswith(".bias"):
        torch.nn.init.constant_(param, val=0)
        continue
    # Any parameter that is neither a weight nor a bias is unexpected here.
    raise RuntimeError(f"{name} not be init")

# Define the loss function
loss = torch.nn.CrossEntropyLoss(reduction="mean")

# Define the optimisation algorithm
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# Train the model
num_epochs = 5
def train():
    """Train the module-level ``net`` for ``num_epochs`` epochs.

    Reads the module-level ``train_iter``, ``loss`` and ``optimizer``. After
    each epoch it prints the number of samples seen, the sample-weighted mean
    loss and the training accuracy. Returns nothing.
    """
    print("begin train...")
    for epoch in range(1, num_epochs + 1):
        total_loss = 0.
        total = 0
        true_sum = 0
        for x, y in train_iter:
            y_hat = net(x)
            # ``ll`` is a scalar; it is treated below as the mean loss over
            # the batch.
            ll = loss(y_hat, y)

            # Clear stale gradients, back-propagate, then update the weights.
            optimizer.zero_grad()
            ll.backward()
            optimizer.step()

            batch = y.shape[0]
            total += batch
            # ``ll`` is a per-sample mean, so weight it by the batch size.
            # ``.item()`` turns it into a plain float: accumulating the
            # tensor itself would keep every batch attached to the autograd
            # graph and make the printed loss a ``tensor(..., grad_fn=...)``.
            total_loss += ll.item() * batch
            # Count the correct predictions for the accuracy.
            true_sum += (torch.argmax(y_hat, dim=-1) == y).sum().item()
        epoch_loss = total_loss / total
        acc = true_sum / total
        print(f"epoch: {epoch}, total: {total}, loss: {epoch_loss}, acc: {acc}")
# Start training.
# The guard matters because the data loaders use worker processes: on
# platforms that start workers with "spawn" (e.g. Windows) each worker
# re-imports this script, and an unguarded call would launch training again
# inside every worker.
if __name__ == "__main__":
    train()
# 网友评论 (reader comments: leftover from the source web page, not part of the script)