美文网首页
Pytorch 常用语法

Pytorch 常用语法

作者: Hiper | 来源:发表于2020-10-29 13:16 被阅读0次

常用模块以及设置

import torch
import numpy as np
from matplotlib import pyplot as plt

dtype = torch.double
device = torch.device("cuda:0")

创建张量

# 转化np矩阵
x = torch.Tensor(x)

# 创建一维等距向量
x = torch.linspace(0, 1, 100, dtype=dtype, device=device)

# 创建全一矩阵,零矩阵
x = torch.ones(n, m, dtype=dtype, device=device)
x = torch.one_like(x, dtype=dtype, device=device)
x = torch.zeros(n, m, dtype=dtype, device=device)
x = torch.one_like(x, dtype=dtype, device=device)

# 创建随机矩阵
x = torch.rand(n, m, dtype=dtype, device=device)
x = torch.randn(n, m, dtype=dtype, device=device)
x = torch.normal(means, std, dtype=dtype, device=device)

张量操作

# 增加维度
x = x.unsqueeze(dim)    # dim=0,1,...

# 转置
x = x.t()

# 大小
print(x.size())

# 切片
x_1 = x[:,1:-2]

常用函数

# 数学函数
y = torch.sin(x)
y = torch.tan(x)
y = torch.atan(x)
y = torch.sqrt(x)
y = torch.relu(x)
y = torch.tanh(x)
y = torch.sigmoid(x)

# 其他函数
y = torch.sum(x, dim = 0)

模块类

class SLNN(torch.nn.Module):
    def __init__(self, N):
        super(SLNN, self).__init__()
        self.dense1 = torch.nn.Linear(N, N)
        self.dense2 = torch.nn.Linear(N, N)
        self.tanh = torch.tanh()
    
    def forward(x):
        out = self.dense1(x)
        out = self.tanh(out)
        out = self.dense2(out)

损失函数与优化器

criterion = torch.nn.MSELoss(reduction='sum')       # 定义损失函数
optimizer = torch.optim.Adam(model_eign.parameters(), lr=1e-4)      # 优化器

迭代

Epoch = 10000
for epoch in range(Epoch):
    y_pred = model(x)

    loss = criterion(y, y_pred)
    if epoch % 100 == 99:
        print('epoch[{}/{}],loss:{:.6f}'.format(epoch, Epoch, loss.item()))

    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()

画图

plt.plot(x.cpu(),y.cpu())           # 画图时需要临时转化变量到cpu上
plt.show()

相关文章

网友评论

      本文标题:Pytorch 常用语法

      本文链接:https://www.haomeiwen.com/subject/wrukvktx.html