前言
Meta 前不久发布了图像分割通用模型 Segment Anything,标志着图像分割领域的 ChatGPT 时刻来临,通用模型在图像处理领域表现出了强大的潜力。业内有人戏称:“Segment Anything 的出现,代表着曾经作为图像处理领域主流的图像分割任务基本不存在了,分割已经没什么任务需要继续去做了 ^_^”。Segment Anything 在自然图像分割领域性能强大,但就医学图像分割(如靶区勾画等)而言,其局限性仍然显著存在。不过,Segment Anything 作为一个强大的工具,可以成为医学图像分割任务前处理或后处理的利器(如对分割的目标区域进行约束等)。以下是一个用 Segment Anything 实现医学图像分割的简单 demo,以供参考;基于此 demo,可以发掘 Segment Anything 更多有趣的玩法。
基本流程
- 图像转化
医学图像基本都是3D图像,而Segment Anything是基于自然图像训练而成,因此,无法直接对3D图像进行分割,所以,需要将3D 医学图像通过指定窗宽窗位转换为RGB3通道的一系列2D图像。
def transformNsCV(self):
    """Window ``self.imgArray`` and rescale the result to the 0-255 range.

    The intensity window is centred on ``self.windowsCenter`` and spans
    ``self.windowsLevel`` intensity units (shifted by +0.5).  Values are
    scaled so the window maps onto [0, 255], truncated toward zero, and
    anything outside the window is clamped to 0 or 255.

    Returns:
        A new floating-point array with the same shape as ``self.imgArray``.
    """
    lower = (2 * self.windowsCenter - self.windowsLevel) / 2.0 + 0.5
    upper = (2 * self.windowsCenter + self.windowsLevel) / 2.0 + 0.5
    scale = 255.0 / (upper - lower)

    # Shift so the lower window bound sits at 0, then stretch to 0-255.
    scaled = np.trunc((self.imgArray - lower) * scale)

    # Clamp intensities that fall outside the window.
    np.putmask(scaled, scaled > 255, 255)
    np.putmask(scaled, scaled < 0, 0)
    return scaled
- 根据先验种子点实现分割
Segment Anything 可以根据人为指定的初始种子点进行分割,类似于区域生长算法,实现一键抠图效果。种子点是一个形状为 N×2 的 numpy 数组(每行为一个点的 (x, y) 像素坐标),具体如下:
# Wrap the loaded SAM model in a point-prompt predictor and give it the
# RGB slice to segment.
predictor = SamPredictor(self.model)
predictor.set_image(targetSlice)

# Point prompts: an (N, 2) array, one (x, y) pixel coordinate per row.
seedsArr = np.array([[100, 245], [500, 345]])

# SAM point labels are binary: 1 = foreground point, 0 = background point.
# Every seed here marks the target, so all labels are 1.
labels = np.ones(len(seedsArr), dtype=int)

# multimask_output=True makes SAM return several candidate masks together
# with their quality scores and low-resolution logits.
masks, scores, logits = predictor.predict(
    point_coords=seedsArr,
    point_labels=labels,
    multimask_output=True,
)
- 对全图进行分割
Segment Anything 也支持对全图做分割,但是速度较慢
def pipelineSegAllImage(self, targetSlice: np.ndarray):
    """Segment everything in one RGB slice with SAM's automatic mask generator.

    Args:
        targetSlice: A single slice as an (H, W, 3) array; it is cast to
            ``uint8`` before being handed to SAM.

    Returns:
        An (H, W) array in which each pixel holds the number of generated
        masks that cover it (0 where no mask was produced).
    """
    # Accumulator takes its shape and dtype from one channel of the input slice.
    maskSliceRes = np.zeros_like(targetSlice[:, :, 0])

    # Fixed: the original referenced the misspelled attribute ``self.mdoel``,
    # which raised AttributeError before any segmentation could run.
    mask_generator = SamAutomaticMaskGenerator(self.model)

    targetSlice = targetSlice.astype(np.uint8)
    masks = mask_generator.generate(targetSlice)

    # Each record carries a boolean "segmentation" array; sum them up.
    for mask in masks:
        maskSliceRes = maskSliceRes + mask["segmentation"].astype(np.uint8)
    return maskSliceRes
其中一张slice的实现效果如下,由于是通过2D实现的分割,所以最终效果可能在z方向上的连续性不够好。
综上,利用Segment Anything实现医学图像分割的全流程代码如下
from segment_anything import sam_model_registry
from segment_anything import SamPredictor,SamAutomaticMaskGenerator
import SimpleITK as sitk
import numpy as np
import torch as t
import cv2
import matplotlib.pyplot as plt
from itertools import product
import random
import sys
from ReadAndWrite import ReadImageBase
class SamMedSegmentation():
    """Slice-by-slice segmentation of a 3-D medical volume with Segment Anything.

    The volume is intensity-windowed to 0-255, every slice along the first
    axis is expanded to a 3-channel image, and each slice is segmented
    independently, either from point prompts or with SAM's automatic mask
    generator.  Because slices are processed independently, results are not
    guaranteed to be continuous along the first axis.
    """

    def __init__(self, imgArray: np.ndarray, windowsCenter: int, windowsLevel: int, seedlists=None):
        """Store the volume and window settings and load the SAM model.

        Args:
            imgArray: 3-D volume indexed as (z, y, x).
            windowsCenter: Centre of the intensity window.
            windowsLevel: Span of the intensity window (used as the window
                width by ``transformNsCV``).
            seedlists: Optional list of ``[x, y]`` pixel coordinates used as
                point prompts on every slice.  When omitted, random prompts
                are drawn for each slice.
        """
        self.imgArray = imgArray
        assert len(self.imgArray.shape) == 3, "imgArray must be a 3-D volume"
        # Remember whether the caller supplied prompts: the original stored
        # ``seedlists`` but then silently ignored it during segmentation.
        self._userSeeds = seedlists is not None
        if seedlists is None:
            self.seedlists = self.generateSeeds(1)
        else:
            self.seedlists = seedlists
        self.windowsCenter = windowsCenter
        self.windowsLevel = windowsLevel
        self.model = self.initializeModel()

    def initializeModel(self, modelType="vit_l",
                        checkPoint=r"/mnt/e/ChromeDwnLoad/sam_vit_l_0b3195.pth",
                        device=None):
        """Build a SAM model from a checkpoint and move it to ``device``.

        Args:
            modelType: Key into ``sam_model_registry``.
            checkPoint: Path of the checkpoint file matching ``modelType``.
            device: Target device; defaults to ``"cuda"`` when CUDA is
                available and ``"cpu"`` otherwise.

        Returns:
            The loaded SAM model.
        """
        if device is None:
            # Fall back to the CPU instead of failing on machines without CUDA.
            device = "cuda" if t.cuda.is_available() else "cpu"
        sam = sam_model_registry[modelType](checkpoint=checkPoint)
        sam.to(device)
        return sam

    def generateSeeds(self, maxIter):
        """Draw ``maxIter`` random point prompts from a fixed coarse grid.

        Candidate coordinates are fixed fractions of the slice width and
        height, all located in the upper-left half of the slice.

        Args:
            maxIter: Number of points to draw (with replacement).

        Returns:
            A list of ``[x, y]`` pixel coordinates.
        """
        _, y, x = self.imgArray.shape
        px = (x // 8, x // 7, x // 6, x // 4, x // 3, x // 2)
        py = (y // 8, y // 7, y // 5, y // 4, y // 3, y // 2)
        # Fixed: candidates are built as (x, y) pairs.  The original paired
        # them as (y, x) while treating them as (x, y), so prompts landed on
        # transposed positions.
        candidates = list(product(px, py))
        return [list(random.choice(candidates)) for _ in range(maxIter)]

    def transformNsCV(self):
        """Window ``self.imgArray`` and rescale the result to the 0-255 range.

        The window is centred on ``self.windowsCenter`` and spans
        ``self.windowsLevel`` intensity units (shifted by +0.5).  Values are
        truncated toward zero and clamped to [0, 255].

        Returns:
            A new floating-point array with the same shape as the volume.
        """
        minWindow = (2 * self.windowsCenter - self.windowsLevel) / 2.0 + 0.5
        maxWindow = (2 * self.windowsCenter + self.windowsLevel) / 2.0 + 0.5
        dFactor = 255.0 / (maxWindow - minWindow)
        scaled = np.trunc((self.imgArray - minWindow) * dFactor)
        return np.clip(scaled, 0, 255)

    def convertMedImg2CV(self, targetArray: np.ndarray):
        """Expand every slice of a 3-D volume into a 3-channel image.

        Args:
            targetArray: Volume indexed as (z, y, x).

        Returns:
            A list with one (y, x, 3) array per slice; the three channels are
            copies of the grey slice and keep its dtype.
        """
        assert len(targetArray.shape) == 3
        # Replicating the grey channel with numpy yields the same pixels as
        # cv2.COLOR_GRAY2RGB but also accepts float64 input, which cv2.cvtColor
        # rejects (and which transformNsCV produces for integer volumes).
        return [np.stack((graySlice, graySlice, graySlice), axis=-1)
                for graySlice in targetArray]

    def pipelineSegAllImage(self, targetSlice: np.ndarray):
        """Segment everything in one RGB slice with SAM's automatic generator.

        Args:
            targetSlice: A single slice as an (H, W, 3) array.

        Returns:
            An (H, W) array in which each pixel holds the number of generated
            masks that cover it.
        """
        maskSliceRes = np.zeros_like(targetSlice[:, :, 0])
        # Fixed: the original used the misspelled attribute ``self.mdoel``.
        mask_generator = SamAutomaticMaskGenerator(self.model)
        targetSlice = targetSlice.astype(np.uint8)
        masks = mask_generator.generate(targetSlice)
        for mask in masks:
            maskSliceRes = maskSliceRes + mask["segmentation"].astype(np.uint8)
        return maskSliceRes

    def pipelineSegPoint(self, targetSlice: np.ndarray):
        """Segment one RGB slice from point prompts.

        Uses the seeds passed to the constructor when there were any;
        otherwise four random prompts are drawn for this slice.

        Args:
            targetSlice: A single slice as an (H, W, 3) array.

        Returns:
            An (H, W) ``uint8`` array holding, per pixel, how many of the
            candidate masks returned by SAM cover it.
        """
        targetSlice = targetSlice.astype(np.uint8)
        maskSliceRes = np.zeros_like(targetSlice[:, :, 0])
        predictor = SamPredictor(self.model)
        predictor.set_image(targetSlice)

        if getattr(self, "_userSeeds", False):
            seeds = self.seedlists
        else:
            seeds = self.generateSeeds(4)
        seedsArr = np.array(seeds)

        # SAM point labels are binary (1 = foreground, 0 = background); the
        # original passed 1..N, values for which SAM defines no meaning.
        labels = np.ones(len(seedsArr), dtype=int)

        masks, scores, logits = predictor.predict(
            point_coords=seedsArr,
            point_labels=labels,
            multimask_output=True,
        )
        for mask in masks:
            maskSliceRes += mask.astype(np.uint8)
        return maskSliceRes

    def processPipeline(self, all=False):
        """Run windowing, RGB conversion and per-slice segmentation.

        Args:
            all: When true, segment each slice with the automatic mask
                generator; otherwise segment from point prompts.

        Returns:
            An array with the shape and dtype of ``self.imgArray`` holding
            the per-slice segmentation results.
        """
        targetArray = self.transformNsCV()
        imgslists = self.convertMedImg2CV(targetArray)
        segmentSlice = self.pipelineSegAllImage if all else self.pipelineSegPoint

        maskResults = np.zeros_like(self.imgArray)
        for idx, rgbSlice in enumerate(imgslists):
            maskResults[idx, :, :] = segmentSlice(rgbSlice)
        return maskResults
if __name__ == '__main__':
    # Demo run: read the first image found under ``imgpath``, segment it from
    # point prompts and write the mask volume to ``outpath``.
    imgpath = r"./0522c0149"
    outpath = r"./out"

    # NOTE(review): ReadImageBase is a project-local helper; indexing it
    # presumably yields (voxel array, metadata) - confirm against ReadAndWrite.
    reader = ReadImageBase(imgpath, outpath, '.nrrd')
    imgArray, detail = reader[0]

    # Window centre 50, window span 350.
    seg = SamMedSegmentation(imgArray, 50, 350)
    maskResults = seg.processPipeline(all=False)

    # Quick look at the result and its largest value.
    print(maskResults)
    print(np.max(maskResults))

    reader.writer(maskResults.astype(np.uint8), detail, "segPoint.nrrd", outPath=outpath)
网友评论