美文网首页
Pytorch机器学习——2 安装和快速上手(四)

Pytorch机器学习——2 安装和快速上手(四)

作者: 辘轳鹿鹿 | 来源:发表于2022-04-01 11:03 被阅读0次

outline

  1. PyTorch安装
  2. Jupyter Notebook使用
  3. NumPy基础知识
  4. PyTorch基础知识

4. PyTorch基础知识

PyTorch的特色之一是提供构建动态计算图的框架,这样网络结构就不再是一成不变的了,甚至可以在运行时修正它们。
在神经网络方面,PyTorch的优点还在于使用了多GPU的强大加速能力、自定义数据加载器和极简的预处理过程等。

4.1 Tensor简介

  • Tensor是PyTorch中的基本对象,意思为张量,表示多维的矩阵,是PyTorch中的基本操作对象之一。与Numpy的ndarray类似,Tensor的声明和获取size可以这样:
import torch
x=torch.Tensor(5,3)
x.size()
  • Tensor的算术运算和选取操作与Numpy一样,因此Numpy相似的运算操作都可以迁移过来
  • Tensor与Numpy的array还可以进行互相转换,有专门的转换函数:
x=torch.rand(5,3)
y=x.numpy()
z=torch.from_numpy(y)
image.png

4.2 Variable 简介

  • Variable是PyTorch的另一个基本对象,可以把它理解为是对Tensor的一个封装。Variable用于放入计算图中以进行前向传播、反向传播和自动求导。
  • 在一个Variable中有三个重要属性
    • data:表示包含的Tensor数据部分
    • grad:表示传播方向的梯度,这个属性是延迟分配的,而且仅允许进行一次
    • creator:表示创建这个Variable的Function的引用,该引用用于回溯整个创建链路。如果是用户创建的Variable,其creator为None,同时这种Variable称作Leaf Variable,autograd只会给Leaf Variable分配梯度。
from torch.autograd import variable
x=torch.rand(4)
print(x)
x=variable(x,requires_grad = True) #声明requires_grad = True,就必须指定grad_variables
y=x*3
grad_variables = torch.FloatTensor([1,2,3,4])
y.backward(grad_variables) #grad_variables 就是y求导时的梯度参数
x.grad
image.png

4.3 CUDA简介

  • 如果安装了支持CUDA版本的PyTorch,就可以启用显卡运算了。torch.cuda用于设置和运行CUDA操作,它会记录当前选择的GPU,并且分配的所有CUDA张量将默认在上面创建,可以使用torch.cuda.device上下文管理器更改所选设备。
  • 除非启用对等存储器访问,否则对于分布不同设备上的张量,任何启动操作的尝试都将引发错误。
torch.cuda.is_available()
x=x.cuda()
y=y.cuda()
x+y

下面是上述代码的运行结果,建议在服务器上跑


image.png

4.4 模型的保存与加载

  • Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()pickle.load()方法
  • PyTorch中也有同样功能的方法提供
import torch
torch.save(model, 'model.pkl')  #保存整个模型
model = torch.load('model.pkl')  #加载整个模型
torch.save(alexnet.state_dict(), 'params.pkl') #保存网络中的参数
alexnet.load_state_dict(torch.load('params.pkl')) #加载网络中的参数
  • 在torchvision.models模块里,PyTorch提供了一些常用的模型:AlexNet、VGG、ResNet、SqueezeNet、DenseNet、Inception v3

  • 以下为加载部分预训练模型的方法,其中model是指我们自己要训练的模型,需要我们自己预先定义

import torch.utils.model_zoo as model_zoo
#加载这类预训练模型的过程中,还可以进行微处理
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}
pretrained_dict = model_zoo.load_url(model_urls['resnet18'])
"""
resnet18 = models.resnet18(pretrained=True)
pretrained_dict = resnet18.state_dict()
"""
model_dict = model.state_dict()
pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} #将pretrained_dict里不属于model_dict的键剔除掉
model_dict.update(pretrained_dict) # 更新现有的model_dict
model.load_state_dict(model_dict) # 加载我们真正需要的state_dict

有用就留个赞吧^_^

相关文章

网友评论

      本文标题:Pytorch机器学习——2 安装和快速上手(四)

      本文链接:https://www.haomeiwen.com/subject/uwvgjrtx.html