Data Preparation
Download the test datasets and place them in the corresponding folders (each dataset folder contains an images/ and a masks/ subfolder, shown below for Kvasir):
|-- TestDataset
|   |-- CVC-300
|   |-- CVC-ClinicDB
|   |-- CVC-ColonDB
|   |-- ETIS-LaribPolypDB
|   |-- Kvasir
|       |-- images
|       |-- masks
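Before running anything, it can help to confirm the layout matches what the test script below expects. A minimal sketch, assuming the ./TestDataset root and the images/masks subfolder names from the tree above:
# sanity_check_data.py -- optional helper, not part of the original post
import os

test_root = './TestDataset'  # assumed dataset root, adjust to your layout
datasets = ['Kvasir', 'CVC-ColonDB', 'CVC-ClinicDB', 'ETIS-LaribPolypDB', 'CVC-300']

for name in datasets:
    img_dir = os.path.join(test_root, name, 'images')
    gt_dir = os.path.join(test_root, name, 'masks')
    n_img = len(os.listdir(img_dir)) if os.path.isdir(img_dir) else 0
    n_gt = len(os.listdir(gt_dir)) if os.path.isdir(gt_dir) else 0
    print('{}: {} images / {} masks'.format(name, n_img, n_gt))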
Environment Setup
Anaconda and a Python environment must be installed beforehand; Python 3.8 is recommended.
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
conda install -c conda-forge imageio
pip install scikit-image timm pyyaml Pillow numpy matplotlib
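A quick way to confirm the environment is usable is to import the key packages and check that the CUDA build of PyTorch is picked up; this sketch only prints versions:
# verify_env.py -- optional check, not part of the original post
import torch, torchvision, timm, imageio, yaml, skimage
print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('torchvision:', torchvision.__version__, '| timm:', timm.__version__)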
Running the Test
- Model and test data configuration
import yaml

datasetTest = ['Kvasir', 'CVC-ColonDB', 'CVC-ClinicDB', 'ETIS-LaribPolypDB', 'CVC-300']
_model_name = 'ESFP_B4_Endo_Best_Balance'
model_type = 'B4'   # backbone variant used by the network definition below; B4 matches _model_name
config = yaml.safe_load(open('Configure.yaml'))
init_trainsize = 352
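Configure.yaml itself is not reproduced in the post; judging from how main() indexes it, it needs a dataset section with a test_<name>_img and test_<name>_label path per dataset. A hedged sketch that generates such a file (the paths are placeholders to adapt; note the trailing slashes, since test_dataset concatenates image_root + filename directly):
# make_configure.py -- hypothetical generator for Configure.yaml, inferred from main()
import yaml

test_root = './TestDataset'  # assumed location of the test sets
dataset_cfg = {}
for name in ['Kvasir', 'CVC-ColonDB', 'CVC-ClinicDB', 'ETIS-LaribPolypDB', 'CVC-300']:
    dataset_cfg['test_' + name + '_img'] = '{}/{}/images/'.format(test_root, name)
    dataset_cfg['test_' + name + '_label'] = '{}/{}/masks/'.format(test_root, name)

with open('Configure.yaml', 'w') as f:
    yaml.safe_dump({'dataset': dataset_cfg}, f)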
- Data loading
import os
from PIL import Image
import torchvision.transforms as transforms

class test_dataset:
    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
        self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png') or f.endswith('.jpg')]
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        image = self.rgb_loader(self.images[self.index])
        image = self.transform(image).unsqueeze(0)      # add batch dimension
        gt = self.binary_loader(self.gts[self.index])   # ground truth keeps its original resolution
        name = self.images[self.index].split('/')[-1]
        if name.endswith('.jpg'):
            name = name.split('.jpg')[0] + '.png'       # predictions are always saved as .png
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')
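As a quick check of the loader in isolation (the Kvasir paths here are assumptions; in the full script they come from Configure.yaml):
# standalone loader check -- paths are assumed, adjust as needed
loader = test_dataset('./TestDataset/Kvasir/images/', './TestDataset/Kvasir/masks/', 352)
image, gt, name = loader.load_data()
print(name, image.shape)   # e.g. torch.Size([1, 3, 352, 352]); gt keeps its original resolution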
- Network definition
import torch
import torch.nn as nn
import torch.nn.functional as F
from Encoder import mit
from Decoder import mlp
from mmcv.cnn import ConvModule
class ESFPNetStructure(nn.Module):

    def __init__(self, embedding_dim = 160):
        super(ESFPNetStructure, self).__init__()

        # Backbone
        if model_type == 'B0':
            self.backbone = mit.mit_b0()
        if model_type == 'B1':
            self.backbone = mit.mit_b1()
        if model_type == 'B2':
            self.backbone = mit.mit_b2()
        if model_type == 'B3':
            self.backbone = mit.mit_b3()
        if model_type == 'B4':
            self.backbone = mit.mit_b4()
        if model_type == 'B5':
            self.backbone = mit.mit_b5()
        self._init_weights()  # load pretrained backbone weights

        # LP Header
        self.LP_1 = mlp.LP(input_dim = self.backbone.embed_dims[0], embed_dim = self.backbone.embed_dims[0])
        self.LP_2 = mlp.LP(input_dim = self.backbone.embed_dims[1], embed_dim = self.backbone.embed_dims[1])
        self.LP_3 = mlp.LP(input_dim = self.backbone.embed_dims[2], embed_dim = self.backbone.embed_dims[2])
        self.LP_4 = mlp.LP(input_dim = self.backbone.embed_dims[3], embed_dim = self.backbone.embed_dims[3])

        # Linear Fuse
        self.linear_fuse34 = ConvModule(in_channels=(self.backbone.embed_dims[2] + self.backbone.embed_dims[3]), out_channels=self.backbone.embed_dims[2], kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse23 = ConvModule(in_channels=(self.backbone.embed_dims[1] + self.backbone.embed_dims[2]), out_channels=self.backbone.embed_dims[1], kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse12 = ConvModule(in_channels=(self.backbone.embed_dims[0] + self.backbone.embed_dims[1]), out_channels=self.backbone.embed_dims[0], kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))

        # Fused LP Header
        self.LP_12 = mlp.LP(input_dim = self.backbone.embed_dims[0], embed_dim = self.backbone.embed_dims[0])
        self.LP_23 = mlp.LP(input_dim = self.backbone.embed_dims[1], embed_dim = self.backbone.embed_dims[1])
        self.LP_34 = mlp.LP(input_dim = self.backbone.embed_dims[2], embed_dim = self.backbone.embed_dims[2])

        # Final Linear Prediction
        self.linear_pred = nn.Conv2d((self.backbone.embed_dims[0] + self.backbone.embed_dims[1] + self.backbone.embed_dims[2] + self.backbone.embed_dims[3]), 1, kernel_size=1)

    def _init_weights(self):
        if model_type == 'B0':
            pretrained_dict = torch.load('./Pretrained/mit_b0.pth')
        if model_type == 'B1':
            pretrained_dict = torch.load('./Pretrained/mit_b1.pth')
        if model_type == 'B2':
            pretrained_dict = torch.load('./Pretrained/mit_b2.pth')
        if model_type == 'B3':
            pretrained_dict = torch.load('./Pretrained/mit_b3.pth')
        if model_type == 'B4':
            pretrained_dict = torch.load('./Pretrained/mit_b4.pth')
        if model_type == 'B5':
            pretrained_dict = torch.load('./Pretrained/mit_b5.pth')
        model_dict = self.backbone.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.backbone.load_state_dict(model_dict)
        print("successfully loaded!!!!")
    def forward(self, x):

        ################## Go through backbone ###################
        B = x.shape[0]

        # stage 1
        out_1, H, W = self.backbone.patch_embed1(x)
        for i, blk in enumerate(self.backbone.block1):
            out_1 = blk(out_1, H, W)
        out_1 = self.backbone.norm1(out_1)
        out_1 = out_1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  # (Batch_Size, self.backbone.embed_dims[0], 88, 88)

        # stage 2
        out_2, H, W = self.backbone.patch_embed2(out_1)
        for i, blk in enumerate(self.backbone.block2):
            out_2 = blk(out_2, H, W)
        out_2 = self.backbone.norm2(out_2)
        out_2 = out_2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  # (Batch_Size, self.backbone.embed_dims[1], 44, 44)

        # stage 3
        out_3, H, W = self.backbone.patch_embed3(out_2)
        for i, blk in enumerate(self.backbone.block3):
            out_3 = blk(out_3, H, W)
        out_3 = self.backbone.norm3(out_3)
        out_3 = out_3.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  # (Batch_Size, self.backbone.embed_dims[2], 22, 22)

        # stage 4
        out_4, H, W = self.backbone.patch_embed4(out_3)
        for i, blk in enumerate(self.backbone.block4):
            out_4 = blk(out_4, H, W)
        out_4 = self.backbone.norm4(out_4)
        out_4 = out_4.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  # (Batch_Size, self.backbone.embed_dims[3], 11, 11)

        # go through LP Header
        lp_1 = self.LP_1(out_1)
        lp_2 = self.LP_2(out_2)
        lp_3 = self.LP_3(out_3)
        lp_4 = self.LP_4(out_4)

        # linear fuse and go pass LP Header
        lp_34 = self.LP_34(self.linear_fuse34(torch.cat([lp_3, F.interpolate(lp_4, scale_factor=2, mode='bilinear', align_corners=False)], dim=1)))
        lp_23 = self.LP_23(self.linear_fuse23(torch.cat([lp_2, F.interpolate(lp_34, scale_factor=2, mode='bilinear', align_corners=False)], dim=1)))
        lp_12 = self.LP_12(self.linear_fuse12(torch.cat([lp_1, F.interpolate(lp_23, scale_factor=2, mode='bilinear', align_corners=False)], dim=1)))

        # get the final output
        lp4_resized = F.interpolate(lp_4, scale_factor=8, mode='bilinear', align_corners=False)
        lp3_resized = F.interpolate(lp_34, scale_factor=4, mode='bilinear', align_corners=False)
        lp2_resized = F.interpolate(lp_23, scale_factor=2, mode='bilinear', align_corners=False)
        lp1_resized = lp_12

        out = self.linear_pred(torch.cat([lp1_resized, lp2_resized, lp3_resized, lp4_resized], dim=1))
        out_resized = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True)

        return out_resized
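A quick way to confirm the decoder wiring is to push a dummy 352x352 batch through a freshly built model and check the output resolution. This is a sketch, not part of the original script, and it assumes the matching ./Pretrained/mit_*.pth file is present so _init_weights can run:
# shape sanity check -- illustrative addition
model = ESFPNetStructure()
with torch.no_grad():
    dummy = torch.randn(1, 3, 352, 352)
    print(model(dummy).shape)   # expected: torch.Size([1, 1, 352, 352])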
- Main function (model invocation and post-processing)
import numpy as np
import imageio
from skimage import img_as_ubyte

def main():
    for _data_name in datasetTest:

        save_path = './results/{}/{}/'.format(_model_name, _data_name)
        os.makedirs(save_path, exist_ok=True)

        test_loader = test_dataset(config['dataset']['test_' + str(_data_name) + '_img'], config['dataset']['test_' + str(_data_name) + '_label'], 352)

        model_path = './SaveModel/{}'.format(_model_name)
        ESFPNetBest = torch.load(model_path + '/ESFPNet.pt')
        ESFPNetBest.eval()

        for i in range(test_loader.size):
            image, gt, name = test_loader.load_data()
            gt = np.asarray(gt, np.float32)
            gt /= (gt.max() + 1e-8)
            image = image.cuda()

            pred = ESFPNetBest(image)
            pred = F.interpolate(pred, size=gt.shape, mode='bilinear', align_corners=False)  # resize back to the original mask size
            pred = pred.sigmoid()
            threshold = torch.tensor([0.5]).to(image.device)  # binarize at 0.5
            pred = (pred > threshold).float() * 1
            pred = pred.data.cpu().numpy().squeeze()
            pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

            imageio.imwrite(save_path + name, img_as_ubyte(pred))

if __name__ == '__main__':
    main()
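The script above only writes binary masks to ./results/. If you also want a rough quantitative summary, a minimal Dice computation over the saved predictions could look like the sketch below. This is an addition, not part of the original script; it reuses the config paths and the test_dataset loader so that prediction and mask filenames line up:
# quick Dice summary over saved predictions -- illustrative addition
def dice(pred, gt, eps=1e-8):
    inter = (pred * gt).sum()
    return (2 * inter + eps) / (pred.sum() + gt.sum() + eps)

for _data_name in datasetTest:
    loader = test_dataset(config['dataset']['test_' + _data_name + '_img'],
                          config['dataset']['test_' + _data_name + '_label'], 352)
    pred_root = './results/{}/{}/'.format(_model_name, _data_name)
    scores = []
    for i in range(loader.size):
        _, gt, name = loader.load_data()
        gt = (np.asarray(gt, np.float32) > 127).astype(np.float32)          # binarize ground truth
        pred = (imageio.imread(pred_root + name) > 127).astype(np.float32)  # saved masks are 0/255
        scores.append(dice(pred, gt))
    print(_data_name, 'mean Dice: {:.4f}'.format(np.mean(scores)))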