美文网首页
利用 Segment Anything实现医学图像分割

利用 Segment Anything实现医学图像分割

作者: 此间不留白 | 来源:发表于2023-04-11 22:54 被阅读0次

前言

Meta 前不久发布了图像分割通用模型Segment Anything, 标志者图像分割领域的chatgpt时刻来临,通用模型在图像处理领域表现出了强大的潜力,业内有人戏称 “Segment Anything 的出现,代表着曾经作为图像处理领域主流的图像分割任务,基本不存在了,分割已经没什么任务需要继续去做了^_^”。 Segment Anything 在自然图像分割领域,性能强大,但是,就医学图像分割 (如靶区勾画等)而言,Segment Anything 的局限性仍然显著存在,不过,Segment Anything 作为一个强大的工具,可以成为医学图像分割任务前处理或后处理的利器(如对分割的目标区域进行约束等)。以下,是一个 Segment Anything实现医学图像分割的简单demo,以供参考,基于此demo,可以发掘更多Segment Anything更为有趣的玩法。

基本流程

  • 图像转化
    医学图像基本都是3D图像,而Segment Anything是基于自然图像训练而成,因此,无法直接对3D图像进行分割,所以,需要将3D 医学图像通过指定窗宽窗位转换为RGB3通道的一系列2D图像。
def transformNsCV(self):
        minWindow = (2 * self.windowsCenter - self.windowsLevel) / 2.0 + 0.5
        maxWindow = (2 * self.windowsCenter + self.windowsLevel) / 2.0 + 0.5
        dFactor = 255.0 / (maxWindow - minWindow)

        transArray = self.imgArray - minWindow
        transArray = np.trunc(transArray * dFactor)
        
        transArray[transArray > 255] = 255
        transArray[transArray < 0] = 0

        return transArray

  • 根据先验种子点实现分割
    Segment Anything 可以根据人为指定的初始种子点进行分割,类似于区域生长算法,实现一键抠图效果 。 种子点是一个 N \times 2 的numpy 数组, 具体如下
predictor = SamPredictor(self.model)
predictor.set_image(targetSlice)
        
seedsArr = np.array([[100,245],[500,345]]))    
labels = [i+1 for i in range(0,len(seedsArr))]
masks, scores, logits = predictor.predict(
 point_coords= seedsArr,
 point_labels= np.array(labels),
multimask_output=True,
        )


  • 全图作为分割
    Segment Anything 也支持对全图做分割,但是速度较慢
def pipelineSegAllImage(self,targetSlice:np.ndarray):
        
        maskSliceRes = np.zeros_like(targetSlice[:,:,0])
        
        mask_generator = SamAutomaticMaskGenerator(self.mdoel)
        targetSlice = targetSlice.astype(np.uint8)
        masks = mask_generator.generate(targetSlice)
        for mask in masks:
            maskArrayBool = mask["segmentation"]
            maskArray = maskArrayBool.astype(np.uint8)
            maskSliceRes = maskSliceRes + maskArray

        return maskSliceRes

其中一张slice的实现效果如下,由于是通过2D实现的分割,所以最终效果可能在z方向上的连续性不够好。


综上,利用Segment Anything实现医学图像分割的全流程代码如下


from segment_anything import sam_model_registry
from segment_anything import SamPredictor,SamAutomaticMaskGenerator
import SimpleITK as sitk
import numpy as np
import torch as t
import cv2
import matplotlib.pyplot as plt
from itertools import product
import random
import sys
from ReadAndWrite import ReadImageBase

class SamMedSegmentation():
    def __init__(self,imgArray:np.ndarray,windowsCenter:int,windowsLevel:int,seedlists=None):
        
        self.imgArray = imgArray
        assert(len(self.imgArray.shape)) == 3
        if seedlists == None:
            self.seedlists = self.generateSeeds(1)

        else:
            self.seedlists = seedlists

        self.windowsCenter = windowsCenter
        self.windowsLevel = windowsLevel
        
        self.model = self.initializeModel()

    def initializeModel(self):
        modelType = "vit_l"
        checkPoint = r"/mnt/e/ChromeDwnLoad/sam_vit_l_0b3195.pth"
        sam = sam_model_registry[modelType](checkpoint=checkPoint)
        sam.to("cuda")

        return sam


    def generateSeeds(self,maxIter):
        z,y,x = self.imgArray.shape
        px = (int(x/8),int(x/7),int(x/6),int(x/4),int(x/3),int(x/2))
        py = (int(y/8),int(y/7),int(y/5),int(y/4),int(y/3),int(y/2))

        points = list(product(py,px))
        pointsNum = list(map(lambda x: np.array(x),points))
        point_xy = list()
        for i in range(0,maxIter):
            points_random = random.choice(pointsNum)
        #print([points_random[0], points_random[1]])
            point_x = points_random[0]
            point_y = points_random[1]
            point_xy.append([point_x,point_y])
        return point_xy
    

    def transformNsCV(self):
        minWindow = (2 * self.windowsCenter - self.windowsLevel) / 2.0 + 0.5
        maxWindow = (2 * self.windowsCenter + self.windowsLevel) / 2.0 + 0.5
        dFactor = 255.0 / (maxWindow - minWindow)

        transArray = self.imgArray - minWindow
        transArray = np.trunc(transArray * dFactor)
        
        transArray[transArray > 255] = 255
        transArray[transArray < 0] = 0

        return transArray




    def convertMedImg2CV(self,targetArray:np.ndarray):
        assert len(targetArray.shape) == 3
        z,y,x = targetArray.shape
        imglist = []
        for z_i in range(0,z):
            targetSlice = targetArray[z_i,:,:]
            targetSliceRGB = cv2.cvtColor(targetSlice,cv2.COLOR_GRAY2RGB)
            #print(type(targetSliceRGB))
            imglist.append(targetSliceRGB)
        return imglist


    def pipelineSegAllImage(self,targetSlice:np.ndarray):
        
        maskSliceRes = np.zeros_like(targetSlice[:,:,0])
        
        mask_generator = SamAutomaticMaskGenerator(self.mdoel)
        targetSlice = targetSlice.astype(np.uint8)
        masks = mask_generator.generate(targetSlice)
        for mask in masks:
            maskArrayBool = mask["segmentation"]
            maskArray = maskArrayBool.astype(np.uint8)
            maskSliceRes = maskSliceRes + maskArray

        return maskSliceRes
    

    def pipelineSegPoint(self,targetSlice:np.ndarray):
        targetSlice  = targetSlice.astype(np.uint8)
        maskSliceRes = np.zeros_like(targetSlice[:,:,0])
        
        predictor = SamPredictor(self.model)
        predictor.set_image(targetSlice)
        
        
        seeds = self.generateSeeds(4)
        seedsArr = np.array(seeds)
        print(seedsArr.shape)    
        #print(seeds)
        labels = [i+1 for i in range(0,len(seedsArr))]
        masks, scores, logits = predictor.predict(
        point_coords= seedsArr,
        point_labels= np.array(labels),
        multimask_output=True,
        )

        for idx,mask in enumerate(masks):
            mask = mask.astype(np.uint8)
            maskSliceRes += mask

        return maskSliceRes


    

    def processPipeline(self,all=False):
        targetArray = self.transformNsCV()
        imgslists = self.convertMedImg2CV(targetArray)
        if False == all:
            maskSlicesRes = list(map(self.pipelineSegPoint,imgslists))
        else:

            maskSlicesRes = list(map(self.pipelineSegAllImage,imgslists))
        maskResults = np.zeros_like(self.imgArray)
        for idx,mask in enumerate(maskSlicesRes):
            maskResults[idx,:,:] = mask

        
        return maskResults
    


if __name__ == '__main__':
    
    imgpath = r"./0522c0149"
    outpath = r"./out"
    reader = ReadImageBase(imgpath,outpath,'.nrrd')
    #fileLen = len(reader)
    imgArray,detail = reader[0]


    seg = SamMedSegmentation(imgArray,50,350)
    maskResults = seg.processPipeline(all=False)
    print(maskResults)
    print(np.max(maskResults))
    maskResults = maskResults.astype(np.uint8)
    reader.writer(maskResults,detail,"segPoint.nrrd",outPath = outpath)


相关文章

  • 基于caffe的FCN图像分割(一)

    前言 在计算视觉领域,除了图像分类,目标检测,目标跟踪之外,图像分割也是研究的热点之一。 图像分割的常用医学图像,...

  • 【图像分割应用】医学图像分割(三)——肿瘤分割

    这是专栏《图像分割应用》的第3篇文章,本专栏主要介绍图像分割在各个领域的应用、难点、技术要求等常见问题。 肿瘤的分...

  • 【图像分割应用】医学图像分割(二)——心脏分割

    这是专栏《图像分割应用》的第2篇文章,本专栏主要介绍图像分割在各个领域的应用、难点、技术要求等常见问题。 相比较脑...

  • 医学图像分割论文

    1. 2019 综述阅读 医学+深度 2019--Going Deep in Medical Image An...

  • 医学图像处理

    一、图像分割 图像分割是前期的工作重点,主要使用了现成的软件来完成图像分割任务:3DMed(中国科学院自动化医学图...

  • 【图像分割应用】医学图像分割(一)——脑区域分割

    从本周开始,新专栏《图像分割应用》就跟大家见面了。本专栏主要介绍图像分割在各个领域的应用、难点、技术要求等常见问题...

  • 医学图像分割及应用

    截至目前,我们已经学习了很多关于图像分割的相关算法,就此,对图像的分割算法做以下总结: 基于边界驱动的分割边缘检测...

  • 衡量语义分割准确率

    1 Dice Score 是医学图像分割结果衡量的常用指标。代表的是ground truth的分割结果,Vpr代表...

  • 图像的能量表达

    上海交通大学 医学图像处理技术 前言 医学图像的分割,若是全依赖人类手工完成,不仅速度慢,而且还容易受到人类医师的...

  • 语义图像分割概览

    摘要:本文讨论如何利用卷积神经网络进行语义图像分割的任务。 语义图像分割,目标是将图像的每个像素标记为所表示的相关...

网友评论

      本文标题:利用 Segment Anything实现医学图像分割

      本文链接:https://www.haomeiwen.com/subject/ofrnddtx.html