一、pytorch多GPU并行
(1)引用库
import torch
(2)加载模型
model = XXX
(3) 并行化
# 检查可用GPU设备是否超过1台
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
二、GPU数据转成list
(1)引用库
import numpy as np
(2) 获得GPU数据
target.data
# or
torch.max(logit, 1)[1].view(target.size()).data
(3)转成list
# 在代码后面加 .cpu().numpy()
target.data .cpu().numpy()
# or
torch.max(logit, 1)[1].view(target.size()).data .cpu().numpy()
网友评论