美文网首页
Pytorch Workflow

Pytorch Workflow

作者: 不到15不改名 | 来源:发表于2019-08-08 21:20 被阅读0次

    Abstract

    Personal understanding of the working paradigm of training an artificial neural network (ANN) based on Pytorch.


    Paradigm

    一、数据(torch.utils.data.DataLoader)
    --> 
    二、模型(torch.nn)
     --> 
    三、策略(损失函数, criterion = torch.nn.BlaBlaLoss)+ 算法(优化算法, optimizer = torch.optim.SGD|Adam|Adadelta...)
     --> 
    四、迭代训练
    (
        FOR 
            1. optimizer.zero_grad() 
            2. outputs_train = net(inputs_train) 
            3. loss_train = criterion(outputs_train, labels_train) 
            4. loss_train.backward() 
            5. optimizer.step() 
        END FOR
    )
    --> 
    五、调参/测试(验证集调参/测试集进行最后的打分)
    (
        with torch.no_grad():
            outputs_test = net(inputs_test)
            loss_test = criterion(outputs_test, labels_test)
            ... other test criterion ...
    )
    --> 
    六、加速(optional)
    (
        # Train on GPU
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net.to(device) 
        inputs, labels = inputs.to(device), labels .to(device)
        # Data parallelism
        net = nn.DataParallel(net)
    )
    

    to be continued...


    References

    Pytorch tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    李航老师: 《统计机器学习》

    相关文章

      网友评论

          本文标题:Pytorch Workflow

          本文链接:https://www.haomeiwen.com/subject/etvfjctx.html