美文网首页
PyTorch 实现简单的二分类器

PyTorch 实现简单的二分类器

作者: 刘小白DOER | 来源:发表于2022-04-21 15:26 被阅读0次

        今天测试一个使用PyTorch 来完成sklearn的moon数据分类的案例。

        本文基于MyML/pytorch moons.py at master · prudvinit/MyML (github.com)

    1、生成moon数据,一共 200 个样本,并使用scatter画出散点图

    2、将样本数据从 numpy 转成 tensor

    3、构建全连接的神经网络,网络包含一个输入层,一个中间层,一个输出层。中间层包含 3 个神经元,使用的激活函数是 tanh,softmax函数计算概率得分,根据大小判断为0或者1。

    整个网络连接情况如下:

    4、损失函数用 CrossEntropyLoss,梯度优化器使用 Adam

    5、开始训练及计算training error,accuracy_score得分0.97

    6、根据loss画出曲线

        plt.plot(losses,linewidth=1)

    7、更直观地展示分类结果,将结果可视化

    相关文章

      网友评论

          本文标题:PyTorch 实现简单的二分类器

          本文链接:https://www.haomeiwen.com/subject/vaiiertx.html