今天测试一个使用PyTorch 来完成sklearn的moon数据分类的案例。
本文基于MyML/pytorch moons.py at master · prudvinit/MyML (github.com)
1、生成moon数据,一共 200 个样本,并使用scatter画出散点图
![](https://img.haomeiwen.com/i24447700/8316a96b19bc7d7c.png)
2、将样本数据从 numpy 转成 tensor
![](https://img.haomeiwen.com/i24447700/7860fe5272717e07.png)
3、构建全连接的神经网络,网络包含一个输入层,一个中间层,一个输出层。中间层包含 3 个神经元,使用的激活函数是 tanh,softmax函数计算概率得分,根据大小判断为0或者1。
![](https://img.haomeiwen.com/i24447700/f1a11f9cb3837c73.png)
整个网络连接情况如下:
![](https://img.haomeiwen.com/i24447700/f5725f578b187bf5.png)
4、损失函数用 CrossEntropyLoss,梯度优化器使用 Adam
![](https://img.haomeiwen.com/i24447700/5a791ab302169b92.png)
5、开始训练及计算training error,accuracy_score得分0.97
![](https://img.haomeiwen.com/i24447700/aec758d21b01d619.png)
6、根据loss画出曲线
plt.plot(losses,linewidth=1)
![](https://img.haomeiwen.com/i24447700/f213ef93d5a4debd.png)
7、更直观地展示分类结果,将结果可视化
![](https://img.haomeiwen.com/i24447700/b30b5349672e9dad.png)
![](https://img.haomeiwen.com/i24447700/bdbfc188b7fa09bf.png)
网友评论