outline
- PyTorch安装
- Jupyter Notebook使用
- NumPy基础知识
- PyTorch基础知识
4. PyTorch基础知识
PyTorch的特色之一是提供构建动态计算图的框架,这样网络结构就不再是一成不变的了,甚至可以在运行时修正它们。
在神经网络方面,PyTorch的优点还在于使用了多GPU的强大加速能力、自定义数据加载器和极简的预处理过程等。
4.1 Tensor简介
- Tensor是PyTorch中的基本对象,意思为张量,表示多维的矩阵,是PyTorch中的基本操作对象之一。与Numpy的ndarray类似,Tensor的声明和获取size可以这样:
import torch
x=torch.Tensor(5,3)
x.size()
- Tensor的算术运算和选取操作与Numpy一样,因此Numpy相似的运算操作都可以迁移过来
- Tensor与Numpy的array还可以进行互相转换,有专门的转换函数:
x=torch.rand(5,3)
y=x.numpy()
z=torch.from_numpy(y)

4.2 Variable 简介
- Variable是PyTorch的另一个基本对象,可以把它理解为是对Tensor的一个封装。Variable用于放入计算图中以进行前向传播、反向传播和自动求导。
- 在一个Variable中有三个重要属性
- data:表示包含的Tensor数据部分
- grad:表示传播方向的梯度,这个属性是延迟分配的,而且仅允许进行一次
- creator:表示创建这个Variable的Function的引用,该引用用于回溯整个创建链路。如果是用户创建的Variable,其creator为None,同时这种Variable称作Leaf Variable,autograd只会给Leaf Variable分配梯度。
from torch.autograd import variable
x=torch.rand(4)
print(x)
x=variable(x,requires_grad = True) #声明requires_grad = True,就必须指定grad_variables
y=x*3
grad_variables = torch.FloatTensor([1,2,3,4])
y.backward(grad_variables) #grad_variables 就是y求导时的梯度参数
x.grad

4.3 CUDA简介
- 如果安装了支持CUDA版本的PyTorch,就可以启用显卡运算了。torch.cuda用于设置和运行CUDA操作,它会记录当前选择的GPU,并且分配的所有CUDA张量将默认在上面创建,可以使用torch.cuda.device上下文管理器更改所选设备。
- 除非启用对等存储器访问,否则对于分布不同设备上的张量,任何启动操作的尝试都将引发错误。
torch.cuda.is_available()
x=x.cuda()
y=y.cuda()
x+y
下面是上述代码的运行结果,建议在服务器上跑

4.4 模型的保存与加载
- Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()和pickle.load()方法
- PyTorch中也有同样功能的方法提供
import torch
torch.save(model, 'model.pkl') #保存整个模型
model = torch.load('model.pkl') #加载整个模型
torch.save(alexnet.state_dict(), 'params.pkl') #保存网络中的参数
alexnet.load_state_dict(torch.load('params.pkl')) #加载网络中的参数
-
在torchvision.models模块里,PyTorch提供了一些常用的模型:AlexNet、VGG、ResNet、SqueezeNet、DenseNet、Inception v3
-
以下为加载部分预训练模型的方法,其中model是指我们自己要训练的模型,需要我们自己预先定义
import torch.utils.model_zoo as model_zoo
#加载这类预训练模型的过程中,还可以进行微处理
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}
pretrained_dict = model_zoo.load_url(model_urls['resnet18'])
"""
resnet18 = models.resnet18(pretrained=True)
pretrained_dict = resnet18.state_dict()
"""
model_dict = model.state_dict()
pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} #将pretrained_dict里不属于model_dict的键剔除掉
model_dict.update(pretrained_dict) # 更新现有的model_dict
model.load_state_dict(model_dict) # 加载我们真正需要的state_dict
有用就留个赞吧^_^
网友评论