美文网首页
CNN-学习彩蛋

CNN-学习彩蛋

作者: 自由调优师_大废废 | 来源:发表于2019-12-12 14:51 被阅读0次

在上篇文章中介绍了 CNN 手写数字识别实战中的整体流程,满足了快速入门上手的需求。
本文是作者在学习过程发现的一些 ‘彩蛋’,感觉挺好用,记录一下~

Tqdm 进度条

Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。毕竟模型训练等待的过程属实无聊,有个进度条看还是很友好的,用起来也很简单~


image.png

模型断点训练

在我们进行模型训练的时候,会因为一些其他因素导致模型训练中断,此时如果再重新训练难免会浪费时间,我们更希望模型能够从断点开始训练。

""""
用于捕捉KeyboardInterrupt错误,效果比try except好得多
可以人为终止训练,并将训练得到的参数保存下来,实现断点训练
"""
def exit(signum, frame):
    print("Model Saved")
    t.save(net.state_dict(), 'conv.pth')
    # 通过raise显示地引发异常
    raise KeyboardInterrupt
# signal.signal(sig,action) sig为某个信号,action为该信号的处理函数
signal.signal(signal.SIGINT, exit)   # 终止   由键盘引起的终端(Ctrl-c)
signal.signal(signal.SIGTERM, exit)  # 终止   进程终止(进程可捕获)

加载模型文件

# 尝试加载模型参数
if os.path.exists('conv.pth'):
    try:
        net.load_state_dict(t.load('conv.pth'))
    except Exception as e:
        print(e)
        print("Parameters Error")

训练过程可视化

在模型训练过程中,通过可视化来描绘 loss 的变化,可以帮助我们更直观的查看模型训练过程。在模型训练中加入以下代码即可:

# 用于绘制动态图
plt.ion() 
losses = []

# 模型训练
EPOCHS = 100
BATCH_SIZE = 500
LR = 0.001
for epoch in tqdm(range(EPOCHS)):
    if epoch % 100 == 0:
        for param_group in optim.param_groups:
            LR = LR * 0.9
            param_group['lr'] = LR
    index = 0
    for i in tqdm(range(int(len(data) / BATCH_SIZE)), total=int(len(data) / BATCH_SIZE)):
        batch_x = data[index:index + BATCH_SIZE]
        batch_y = y[index:index + BATCH_SIZE]
        prediction = net.forward(batch_x)
        loss = criterion(prediction, batch_y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        index = index + BATCH_SIZE
        print(loss.item())

        #用于绘制动态图
        if loss <= 0.3:
            losses.append(loss)
        plt.plot(losses)
        #延时关闭
        plt.pause(0.001)
plt.ioff()

测试效果图:


image.png

相关文章

网友评论

      本文标题:CNN-学习彩蛋

      本文链接:https://www.haomeiwen.com/subject/aszzgctx.html