美文网首页程序员
线性回归(MXNet 版本 gluon)

线性回归(MXNet 版本 gluon)

作者: zidea | 来源:发表于2019-09-14 03:44 被阅读0次
mxnet.jpg

我们可以使用 MXNet 的 gluon 模块来使用预定义方法来简化我们神经网络

net = gluon.nn.Sequential()

当我们手写模型的时候,需要声明模型参数,然后再使用他们来构建模型。但在 gluon 提供了大量预先定义好的层,这样我们就只需要关注使用哪些层来构建模型。这里使用 Sequential 来层串起来,预先定义一个空的模型

net.add(gluon.nn.Dense(1))

Dense 其实就是一个全连接层,到这里看起开有点 keras 的设计,都差不多了。不过 MXNet 是国人感觉更亲近。这里简化了,我们在 Dense 无需指定输入节点是多少,MXNet 会根据你的输入进行计算,只需要指定输出就可以。

net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(1))

print(net)
Sequential(
  (0): Dense(None -> 1, linear)
)
epochs = 5
batch_size = 10
for e in range(epochs):
    total_loss = 0
    for data, label in data_iter:
        with ag.record():
            output = net(data)
            loss = square_loss(output,label)
        loss.backward()
        trainer.step(batch_size)
        total_loss += nd.sum(loss).asscalar()
    print("Epoch %d, average loss: %f " % (e, total_loss/num_examples))
dense = net[0]
print(true_w,dense.weight.data())
Epoch 0, average loss: 0.917113
Epoch 1, average loss: 0.000049
Epoch 2, average loss: 0.000049
Epoch 3, average loss: 0.000050
Epoch 4, average loss: 0.000049
([2, -3.4],
[[ 1.9996808 -3.399679 ]]

相关文章

网友评论

    本文标题:线性回归(MXNet 版本 gluon)

    本文链接:https://www.haomeiwen.com/subject/kahhectx.html