美文网首页程序员想法
Torch深度学习框架中一些常用API,BP反向传播通用代码

Torch深度学习框架中一些常用API,BP反向传播通用代码

作者: 千与编程 | 来源:发表于2020-05-24 18:03 被阅读0次

Torchvision是独立于PyTorch的关于图像操作的一个工具库,目前包括六个模块:

torchvision.datasets:Torch框架中常用的数据集集成与使用

torchvision.models:经典模型积累,torchvision.models.resnet18

torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor与numpy 和PIL Image的互换等

torchvision.utils:其他工具,比如产生一个图像网格等

from torch import transform

torchvision中常用的数据扩增方法

transforms.CenterCrop 对图片中心进行裁剪

transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换

transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像

transforms.Grayscale 对图像进行灰度变换

transforms.Pad 使用固定值进行像素填充

transforms.RandomAffine 随机仿射变换

transforms.RandomCrop 随机区域裁剪

transforms.RandomHorizontalFlip 随机水平翻转

transforms.RandomRotation 随机旋转

transforms.RandomVerticalFlip 随机垂直翻转

import sys

sys-系统特定的参数和功能

torch.manual_seed(args.seed) #为CPU设置种子用于生成随机数,以使得结果是确定的

if args.cuda:

torch.cuda.manual_seed(args.seed)#为当前GPU设置随机种子;如果使用多个GPU,应该使用torch.cuda.manual_seed_all()为所有的GPU设置种子。

from keras.layers import UpSampling2D

UpSampling2D为pooling2D的相反,上采样层

BP反向传播通用代码:

使用Torch框架的进行模型训练的BP反向传播代码:

model=Model(input)

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

#模型训练

for e in range(1000):

var_x = Variable(train_x)

var_y = Variable(train_y)

out = model(var_x)# 前向传播

loss = criterion(out, var_y)

optimizer.zero_grad()

loss.backward()

optimizer.step()

if (e + 1) % 100 == 0: # 每 100 次输出结果

print(‘Epoch: {}, Loss: {:.5f}’.format(e + 1,loss.data[0]))

model.eval()

data_X = data_X.reshape(-1, 1, 2)

data_X = torch.from_numpy(data_X)

var_data = Variable(data_X)

pred_test = model(var_data) # 测试集的预测结果

pred_test = pred_test.view(-1).data.numpy()

matplotlib.pyplot的API说明

plt.legend()#绘图图例

import numpy as np

import pandas as pd

data_csv = data_csv.dropna() # 滤除缺失数据

dataset = data_csv.values # 获得csv的值

dataset = dataset.astype(‘float32’)

max_value = np.max(dataset) # 获得最大值

min_value = np.min(dataset) # 获得最小值

scalar = max_value - min_value # 获得间隔数量

相关文章

网友评论

    本文标题:Torch深度学习框架中一些常用API,BP反向传播通用代码

    本文链接:https://www.haomeiwen.com/subject/sxejahtx.html