美文网首页程序员
线性回归(MXNet 版本 gluon)

线性回归(MXNet 版本 gluon)

作者: zidea | 来源:发表于2019-09-14 03:44 被阅读0次
    mxnet.jpg

    我们可以使用 MXNet 的 gluon 模块来使用预定义方法来简化我们神经网络

    net = gluon.nn.Sequential()
    

    当我们手写模型的时候,需要声明模型参数,然后再使用他们来构建模型。但在 gluon 提供了大量预先定义好的层,这样我们就只需要关注使用哪些层来构建模型。这里使用 Sequential 来层串起来,预先定义一个空的模型

    net.add(gluon.nn.Dense(1))
    

    Dense 其实就是一个全连接层,到这里看起开有点 keras 的设计,都差不多了。不过 MXNet 是国人感觉更亲近。这里简化了,我们在 Dense 无需指定输入节点是多少,MXNet 会根据你的输入进行计算,只需要指定输出就可以。

    net = gluon.nn.Sequential()
    net.add(gluon.nn.Dense(1))
    
    print(net)
    
    Sequential(
      (0): Dense(None -> 1, linear)
    )
    
    epochs = 5
    batch_size = 10
    for e in range(epochs):
        total_loss = 0
        for data, label in data_iter:
            with ag.record():
                output = net(data)
                loss = square_loss(output,label)
            loss.backward()
            trainer.step(batch_size)
            total_loss += nd.sum(loss).asscalar()
        print("Epoch %d, average loss: %f " % (e, total_loss/num_examples))
    dense = net[0]
    print(true_w,dense.weight.data())
    
    Epoch 0, average loss: 0.917113
    Epoch 1, average loss: 0.000049
    Epoch 2, average loss: 0.000049
    Epoch 3, average loss: 0.000050
    Epoch 4, average loss: 0.000049
    ([2, -3.4],
    [[ 1.9996808 -3.399679 ]]
    

    相关文章

      网友评论

        本文标题:线性回归(MXNet 版本 gluon)

        本文链接:https://www.haomeiwen.com/subject/kahhectx.html