美文网首页
【PyTorch实战】Fully Connected Netw

【PyTorch实战】Fully Connected Netw

作者: HE_EH | 来源:发表于2017-08-22 17:45 被阅读0次

    1. 简介

    (1) 结构

    简单的三层结构,第一层为输入层,第二层为隐藏层,第三层为输出层

    (2) 激活函数

    2.模型设计

    (1) Model

    import torch.nn as nn

    from collections import OrderedDict

    layers = OrderedDict()      # 创建顺序的dict结构

    for i, n_hidden in enumerate(n_hiddens):

        layers['fc{}'.format(i+1)] = nn.Linear(current_dims, n_hidden) 

        layers['relu{}'.format(i+1)] = nn.ReLU()

        layers['drop{}'.format(i+1)] = nn.Dropout(0.2)

        current_dims = n_hidden

    layers['out'] = nn.Linear(current_dims, n_class)

    model = nn.Sequential(layers)    # 顺序的执行layers

    print(model)

    model = torch.nn.DataParallel(model, device_ids= range(ngpu)) # 数据并行

    (2) Optimizer

    # 采用随机梯度下降算法

    # lr表示学习率

    # weight_decay表示权重衰减,防止模型过拟合

    # momentum加速模型的迭代,参见3

    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9) 

    for epoch in range(epochs):

        model.train() # 训练模式

        if epoch in [80,200]:

            optimizer.param_groups[0]['lr'] *= 0.1 # 随着模型跌打的次数增加,学习率降低

        for batch_idx, (data, target) in enumerate(train_data):

            data, target = Variable(data), Variable(target)

            optimizer.zero_grad() # 清除上一轮的梯度,否则会进行累加

            output = model(data) # 动态图结构

            loss = F.cross_entropy(output, target) # 交叉熵

            loss.backward()

            optimizer.step()

    (3) Evaluation

    # 模型训练时,按照epoch的次数进行效果评估

    if epoch % 10 == 0:

        model.eval() #  评估模式

        test_loss = 0

        correct = 0

        for data, target in test_date:

            data, target = Variable(data, volatile=True), Variable(target)

            output = model(data)

            test_loss += F.cross_entropy(output, target).data[0]

            pred = output.data.max(1)[1]

            correct += pred.cpu().eq(indx_target).sum()

        test_loss = test_loss / len(test_loader)

        acc = 100. * correct / len(test_loader.dataset)

    3. 参考资料

    1. https://github.com/aaron-xichen/pytorch-playground

    2. PyTorch API

    3. On the importance of initialization and momentum in deep learning

    4.Fully Connected Neural Network Algorithms

    5.Fully Connected Neural Network与Activation Function

    相关文章

      网友评论

          本文标题: 【PyTorch实战】Fully Connected Netw

          本文链接:https://www.haomeiwen.com/subject/nzildxtx.html