简单记录做cs231n作业学到的pytorch的小技巧
transform的输入必须是PIL image,并且一次只支持一张图像的增强
def preprocess(img, size=224):
transform = T.Compose([
T.Resize(size),
T.ToTensor(),
T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
std=SQUEEZENET_STD.tolist()),
T.Lambda(lambda x: x[None]), # ???
])
return transform(img)
torch gather用法
- 要求s, y都是二维tensor或者三维tensor,shape可以不同
# Example of using gather to select one entry from each row in PyTorch
def gather_example():
N, C = 4, 5
s = torch.randn(N, C)
y = torch.LongTensor([1, 2, 1, 3])
print(s)
print(y)
print(s.gather(1, y.view(-1, 1)).squeeze())
# print(s.gather(1, y.view(-1, 1)).squeeze())
gather_example()
torch max 的常见用法
x = torch.randn(1,2)
print(x.max(1)) # max(values, index)二元组
print()
print(x.max(1)[0], x.max(1)[1])
# 输出
"""
torch.return_types.max(
values=tensor([-0.0182]),
indices=tensor([0]))
tensor([-0.0182]) tensor([0])"""
# 想要得到index 直接用.argmax(dim=1)
a = torch.arange(0, 6).view(2,3)
print(a)
a.argmax(1)
PIL.Image.fromarray 用法
[图片上传失败...(image-b39634-1575038993735)]
PIL.Image.open 用法
[图片上传失败...(image-8edb96-1575038993735)]
torch 需要求导的变量的原地操作
x.data.copy_(jitter(img.data, ox, oy))
需要使用x.data.copy_()
方法实现
网友评论