今天测试一个使用PyTorch 来完成sklearn的moon数据分类的案例。
本文基于MyML/pytorch moons.py at master · prudvinit/MyML (github.com)
1、生成moon数据,一共 200 个样本,并使用scatter画出散点图
2、将样本数据从 numpy 转成 tensor
3、构建全连接的神经网络,网络包含一个输入层,一个中间层,一个输出层。中间层包含 3 个神经元,使用的激活函数是 tanh,softmax函数计算概率得分,根据大小判断为0或者1。
整个网络连接情况如下:
4、损失函数用 CrossEntropyLoss,梯度优化器使用 Adam
5、开始训练及计算training error,accuracy_score得分0.97
6、根据loss画出曲线
plt.plot(losses,linewidth=1)
7、更直观地展示分类结果,将结果可视化
网友评论