美文网首页
三维深度学习-多线程读取vtkImageData

三维深度学习-多线程读取vtkImageData

作者: 药柴 | 来源:发表于2018-09-10 16:44 被阅读0次

在深度学习最常用的卷积神经网络中,要求数据为具有空间局部性的多维矩阵或者说张量。这与广泛应用的三维模型格式例如STL这种保存三角面片的存储方式不一致。因此,采用体素化的方式对输入进行处理。

以VTK为例,在读入了vtkPolyData后,采用vtkPolyDataToImageStencilExample)的方式对三维模型进行转换,类似的转换方法还有vtkVoxelModeller,但相比之下效率极低。

不过,这样的方法还是较为缓慢,尤其是当输出体素模型规模较大时(如128x128x128),在实际使用中,会使模型文件读取占据了大量开销。不过,由于这个转换本身是可以重复利用的,因此在定义数据集时,加入了cache模式,PyTorch样例代码如下:

class Dataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None, cache=False):
        self.frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        if cache:
            self.cache = [None for i in range(len(self.frame))]
            for i in range(len(self.frame)):
                print('Caching record #%d\r' % (count))
                self.cache[count] = self.read(i)
        else:
            self.cache = None

    def __len__(self):
        return len(self.frame)

    def read(self, idx):
        """Read your data here."""
        return sample

    def __getitem__(self, idx):
        if self.cache:
            sample = self.cache[idx]
        else:
            sample = self.read(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample

实践中发现这样建立缓存还是存在读取效率不足的问题,因此再次改写了一下,变成多线程的形式。

def __init__(self, csv_file, root_dir, transform=None, cache=False, thread=4):
    self.landmarks_frame = pd.read_csv(csv_file)
    self.root_dir = root_dir
    self.transform = transform
    if cache:
        self.cache = [None for i in range(len(self.landmarks_frame))]
        pool = multiprocessing.Pool(processes=thread)
        irange = range(len(self.landmarks_frame))
        count = 0
        for sample in pool.imap(self.read, irange):
            print('Caching record #%d\r' % (count))
            self.cache[count] = sample
            count += 1
    else:
        self.cache = None

可惜的是,这样的改写并不能成功,因为在multiprocessing中传递结果时用到了pickle进行数据的传递,而vtkImageData作为比较特殊的对象无法被pickle序列化。为了解决这个问题,简单调用了vtk.util.numpy_support里的一些方法,完成vtkImageData与Numpy array之间的无损转换。

def voxel2array(self, img):
    # Up to support for 3 dimensions for this line
    rows, cols, _ = img.GetDimensions()

    sc = img.GetPointData().GetScalars()
    arr = numpy_support.vtk_to_numpy(sc)
    arr = array.reshape(rows, cols, -1)
    spacing = img.GetSpacing()
    origin = img.GetOrigin()

    return arr, spacing, origin

def array2voxel(self, arr, spacing, origin):

    vtk_data = numpy_support.numpy_to_vtk(
        arr.ravel(), array_type=vtk.VTK_UNSIGNED_CHAR)
    img = vtk.vtkImageData()
    img.SetDimensions(array.shape)
    img.SetSpacing(spacing)
    img.SetOrigin(origin)
    img.GetPointData().SetScalars(vtk_data)

    return img

重点是vtkImageData中还留存着其体素的spacing信息和图像的整体坐标信息。
突然想到,在体素化前利用一些三维模型降采样方法对牙齿模型进行降采样,是否能够大大加速体素化。

相关文章

网友评论

      本文标题:三维深度学习-多线程读取vtkImageData

      本文链接:https://www.haomeiwen.com/subject/baisgftx.html