美文网首页
PyTorch张量基础操作

PyTorch张量基础操作

作者: 赤色要塞满了 | 来源:发表于2022-03-30 16:51 被阅读0次

1. 安装

参考官网,略。

2. 数据操作

创建:

import torch

# 可指定dtype和device
torch.empty(5, 3)
torch.rand(5, 3)
torch.randn(5, 3)
torch.zeros(5, 3)
torch.tensor([1,2,3])
torch.ones(5, 3)
torch.eye(5, 3)
torch.arange(1, 20, 3)
torch.linspace(1, 20, 5)
torch.normal(1, 1)
torch.uniform(1, 20)
torch.randperm(10)
torch.randn(1).item()

PyTorch操作inplace版本都有后缀_, 例如x.copy_(y), x.t_()
索引选择出来的结果与原数据共享内存。

x = torch.randn(5, 3)
torch.index_select(x, 1, torch.tensor([0, 2]))

# 替换维度后按索引采集 0-替换行 1-替换列
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.gather(t, 1, torch.tensor([[2, 0, 1], [0, 1, 2], [1, 0, 2]]))

view()也共享数据,但内存地址不同。clone()不共享数据,但记录在计算图中,梯度回传到副本时也会传到源Tensor。

简单计算:

t.trace()
t.diag()
t.triu()
t.tril()
t.t()

x = torch.randn(3, 3, dtype=torch.float64)
x.inverse() # 逆矩阵

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
a.svd() # 奇异值分解

a = torch.tensor([1, 2], dtype=torch.float32)
b = torch.tensor([4, 5], dtype=torch.float32)
a.dot(b) # 只支持一维向量内积、点积

a = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.float32)
b = torch.tensor([[1, 2, 3, 4], [2, 3, 4, 5], [4, 5, 6, 7]], dtype=torch.float32)
x = torch.ones(2, 4)
torch.mm(a, b) # 矩阵乘法,内积/点积
torch.addmm(x, a, b) # 乘完再加

a = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.float32)
b = torch.tensor([[2, 3, 4], [4, 5, 6]], dtype=torch.float32)
a.cross(b, dim=1) # 外积、叉积

a.numpy()
b = torch.from_numpy(np.ones(5))

类型dtype转换:

a.int()
a.to(torch.int32)
a.float()
a.to(torch.float32)
a.type('torch.FloatTensor')

GPU:

if torch.cuda.is_available():
    device = torch.device("cuda")
    y = torch.ones_like(x, device=device)
    x = x.to(device)  

计算梯度:

x = torch.ones(2, 2, requires_grad=True)
x.grad_fn
x.is_leaf
x.requires_grad_(True)
y = x * 2
y.backward() # 标量,否则传入y同形状的Tensor
x.grad
x.grad.data.zero_() # 重新backward需要清零,否则叠加
x.data *= 100 # 只改变值,不会记录在计算图
with torch.no_grad(): # 不追踪梯度
    y2 = x ** 3

相关文章

网友评论

      本文标题:PyTorch张量基础操作

      本文链接:https://www.haomeiwen.com/subject/vjaijrtx.html