之前利用PIL把dicom的slice保存为了16位灰度图, 用torchvision.transform
做图像增强时发现会报错.
Dataset
的__getitem__
函数如下
def __getitem__(self, idx):
pth = self.dataset[idx]
img = Image.open(pth) # 范围为 [0-2048] 的16位tiff图片
img = torchvision.transform.ToTensor()(img)
return img
output:
RuntimeError: shape '[64, 64, 5]' is invalid for input of size 8192
查询了一下torchvision.transform.ToTensor()
函数, 发现对输入值域要求为[0-255]
. 估计是我[0-2048]
的范围出发了某种判断, 使得该函数以为输入图片为某种其他格式.
将函数改为如下, 解决了问题
def __getitem__(self, idx):
pth = self.dataset[idx]
img = np.array(Image.open(pth), dtype='float32') / 2048 # 范围为 [0-1] 的单精度`numpy`数组
img = Image.fromarray(img)
img = torchvision.transform.ToTensor()(img)
return img
注意, 若不显式注明dtype='float32'
, 会自动转换为float64
的tensor
, 不确定对训练结果和速度有何影响 (pytorch
的默认数据类型为float32
).
总之, 下次直接把图片保存为numpy
格式会更方便些.
网友评论