美文网首页
操练代码之优化器

操练代码之优化器

作者: 万州客 | 来源:发表于2024-08-18 19:55 被阅读0次

就是Adam,Adagrad,RMSprop,SGD,Momentum这5个优化器。

一,代码

import torch
import torch.nn

import torch.utils.data as Data
import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams['font.sans-serif'] = ['SimHei']

x = torch.unsqueeze(torch.linspace(-1, 1, 500), dim=1)
y = x.pow(3)

LR = 0.01
batch_size = 15
epoches = 5
torch.manual_seed(10)

dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

class Net(torch.nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden_layer = torch.nn.Linear(n_input, n_hidden)
        self.output_layer = torch.nn.Linear(n_hidden, n_output)

    def forward(self, input):
        x = torch.relu(self.hidden_layer(input))
        output = self.output_layer(x)
        return output


def train():
    net_SGD = Net(1, 10, 1)
    net_Monmentum = Net(1, 10, 1)
    net_AdaGrad = Net(1, 10, 1)
    net_RMSprop = Net(1, 10, 1)
    net_Adam = Net(1, 10, 1)
    nets = [net_SGD, net_Monmentum, net_RMSprop, net_AdaGrad, net_Adam]

    optimizer_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
    optimizer_Momentum = torch.optim.SGD(net_SGD.parameters(), lr=LR, momentum=0.6)
    optimizer_RMSprop = torch.optim.RMSprop(net_SGD.parameters(), lr=LR, alpha=0.9)
    optimizer_AdaGrad = torch.optim.Adagrad(net_SGD.parameters(), lr=LR, lr_decay=0)
    optimizer_Adam = torch.optim.Adam(net_SGD.parameters(), lr=LR, betas=(0.9, 0.99))
    optimizers = [optimizer_SGD, optimizer_Momentum, optimizer_RMSprop, optimizer_AdaGrad, optimizer_Adam]

    loss_function = torch.nn.MSELoss()
    losses = [[], [], [], [], []]

    for epoch in range(epoches):
        for step, (batch_x, batch_y) in enumerate(loader):
            for net, optimizer, loss_list in zip(nets, optimizers, losses):
                pred_y = net(batch_x)
                loss = loss_function(pred_y, batch_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_list.append(loss.data.numpy())
    print(losses)
    plt.figure(figsize=(12, 7))
    labels = ['SGD', 'Momentum', 'RMSprop', 'AdaGrad', 'Adam']
    for i, loss in enumerate(losses):
        print(loss, '----------')
        plt.plot(loss, label=labels[i])
    plt.legend(loc='upper right', fontsize=15)
    plt.tick_params(labelsize = 13)
    plt.xlabel('train step', size=15)
    plt.ylabel('model loss', size=15)
    plt.ylim((0, 0.3))
    plt.show()

if __name__ == '__main__':
    train()

二,输出截图


2024-08-19 19_52_29-ch2 – 333.py.png

相关文章

  • 编译器前端和后端

    编译器粗略分为词法分析,语法分析,类型检查,中间代码生成,代码优化,目标代码生成,目标代码优化。把中间代码生成及之...

  • 编译原理——寄存器

    •代码生成是编译器的最后阶段。代码生成器通过前端产生的中间表示法或者通过代码优化器在代码优化阶段,映射到目标程序中...

  • iOS的性能优化

    1、ipa包体积优化 1.1 编译配置优化:编译器代码层面优化Optimize Level;Bitcode(较难...

  • 代码文件编译生成过程完成的事情

    编译过程可分为6步:扫描(词法分析)、语法分析、语义分析、源代码优化、代码生成、目标代码优化。 词法分析:扫描器(...

  • 网页性能优化

    主要内容如下 代码层面的优化 缓存 http 减小打包体积 代码层面的优化 csswill-change告诉浏览器...

  • 编译原理概述

    编译器原理 词法分析器 语法分析器 语义分析器 中间代码生成 符号表 独立机器的代码优化器 代码生成器 依赖于机器...

  • 编译器优化部分代码

    我们简单写一些代码看编译器优化前后的对比。编译器没有优化时 在Build Setting 搜索optimizati...

  • 编译器优化

    首先我们先看以下代码: 编译器优化优化的是什么呢,优化的是底层代码执行逻辑,使项目执行更加高效。汇编是最接近底层的...

  • 0003-keras自定义优化器

    原文 keras优化器的代码 自定义一个SGD优化器 实现“软batch” 假如模型比较庞大,自己的显卡最多也就能...

  • 编译器想做什么

    编译器就程序员写的代码变成CPU能理解机器代码。编译器的指令重排指开启编译器优化后,在不影响代码行为的前提下,代码...

网友评论

      本文标题:操练代码之优化器

      本文链接:https://www.haomeiwen.com/subject/dlnpkjtx.html