在上篇文章中介绍了 CNN 手写数字识别实战中的整体流程,满足了快速入门上手的需求。
本文是作者在学习过程发现的一些 ‘彩蛋’,感觉挺好用,记录一下~
Tqdm 进度条
Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。毕竟模型训练等待的过程属实无聊,有个进度条看还是很友好的,用起来也很简单~

模型断点训练
在我们进行模型训练的时候,会因为一些其他因素导致模型训练中断,此时如果再重新训练难免会浪费时间,我们更希望模型能够从断点开始训练。
""""
用于捕捉KeyboardInterrupt错误,效果比try except好得多
可以人为终止训练,并将训练得到的参数保存下来,实现断点训练
"""
def exit(signum, frame):
print("Model Saved")
t.save(net.state_dict(), 'conv.pth')
# 通过raise显示地引发异常
raise KeyboardInterrupt
# signal.signal(sig,action) sig为某个信号,action为该信号的处理函数
signal.signal(signal.SIGINT, exit) # 终止 由键盘引起的终端(Ctrl-c)
signal.signal(signal.SIGTERM, exit) # 终止 进程终止(进程可捕获)
加载模型文件
# 尝试加载模型参数
if os.path.exists('conv.pth'):
try:
net.load_state_dict(t.load('conv.pth'))
except Exception as e:
print(e)
print("Parameters Error")
训练过程可视化
在模型训练过程中,通过可视化来描绘 loss 的变化,可以帮助我们更直观的查看模型训练过程。在模型训练中加入以下代码即可:
# 用于绘制动态图
plt.ion()
losses = []
# 模型训练
EPOCHS = 100
BATCH_SIZE = 500
LR = 0.001
for epoch in tqdm(range(EPOCHS)):
if epoch % 100 == 0:
for param_group in optim.param_groups:
LR = LR * 0.9
param_group['lr'] = LR
index = 0
for i in tqdm(range(int(len(data) / BATCH_SIZE)), total=int(len(data) / BATCH_SIZE)):
batch_x = data[index:index + BATCH_SIZE]
batch_y = y[index:index + BATCH_SIZE]
prediction = net.forward(batch_x)
loss = criterion(prediction, batch_y)
optim.zero_grad()
loss.backward()
optim.step()
index = index + BATCH_SIZE
print(loss.item())
#用于绘制动态图
if loss <= 0.3:
losses.append(loss)
plt.plot(losses)
#延时关闭
plt.pause(0.001)
plt.ioff()
测试效果图:

网友评论