美文网首页
小黑的Python日记:Unet简单实现裂缝分割

小黑的Python日记:Unet简单实现裂缝分割

作者: 小黑的自我修养 | 来源:发表于2019-04-03 18:21 被阅读0次
    大噶好,我系小黑喵

    裂缝数据集

    数据集地址:https://github.com/cuilimeng/CrackForest-dataset
    结构:

      --project
        main.py
         --image
            --train
               --data
               --groundTruth
            --val
               --data
               --groundTruth
    

    我手动将数据集做成这个格式,其中trian84张,val34张,都保存为了jpg图像。

    Unet

    论文地址:http://www.arxiv.org/pdf/1505.04597.pdf
    代码来源:https://github.com/JavisPeng/u_net_liver
    上面代码中,作者将Unet运用于liver识别,和裂缝一样,都只有一个mask,因而我们可以直接使用上述代码。

    Unet结构

    需要修改dataset.py为自己的数据集,其他小小改动即可。

    #dataset.py
    import torch.utils.data as data
    import PIL.Image as Image
    import os
    
    
    def make_dataset(rootdata,roottarget):#获取img和mask的地址
        imgs = []
        filename_data = [x for x in os.listdir(rootdata)]
        for name in filename_data:
            img = os.path.join(rootdata, name)
            mask = os.path.join(roottarget, name)
            imgs.append((img, mask))#作为元组返回
        return imgs
    
    
    class MyDataset(data.Dataset):
        def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
            imgs = make_dataset(rootdata,roottarget)
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform
    
        def __getitem__(self, index):
            x_path, y_path = self.imgs[index]
            img_x = Image.open(x_path).convert('L')#读取并转换为二值图像
            img_y = Image.open(y_path).convert('L')
            if self.transform is not None:
                img_x = self.transform(img_x)
            if self.target_transform is not None:
                img_y = self.target_transform(img_y)
            return img_x, img_y
    
        def __len__(self):
            return len(self.imgs)
    
    #main.py
    import numpy as np
    import torch
    import argparse
    from torch.utils.data import DataLoader
    from torch import autograd, optim
    from torchvision.transforms import transforms
    from unet import Unet
    from dataset import MyDataset
    
    # 是否使用cuda
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    x_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # 复活了,这里修改就没错误了
    ])
    
    # mask只需要转换为tensor
    y_transforms = transforms.ToTensor()
    
    
    def train_model(model, criterion, optimizer, dataload, num_epochs=10):
        for epoch in range(0,num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            dt_size = len(dataload.dataset)
            epoch_loss = 0
            step = 0
            for x, y in dataload:
                step += 1
                inputs = x.to(device)
                labels = y.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print("%d/%d,train_loss:%0.3f" %
                      (step,
                       (dt_size - 1) // dataload.batch_size + 1, loss.item()))
            print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.cpu().state_dict(), 'weights_%d.pth' % epoch)
        return model
    
    
    #训练模型
    def train():
        batch_size = 1
        liver_dataset = MyDataset(
            "image/train/data", "image/train/gt",transform=x_transforms, target_transform=y_transforms)
        dataloaders = DataLoader(
            liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        train_model(model, criterion, optimizer, dataloaders)
    
    
    #显示模型的输出结果
    def test():
        liver_dataset = MyDataset(
            "image/val/data", "image/val/gt", transform=x_transforms, target_transform=y_transforms)
        dataloaders = DataLoader(liver_dataset, batch_size=1)
        import matplotlib.pyplot as plt
        plt.ion()
        with torch.no_grad():
            for x, _ in dataloaders:
                y = model(x)
                img_y = torch.squeeze(y).numpy()
                plt.imshow(img_y)
                plt.pause(0.01)
            plt.show()
    
    
    if __name__ == '__main__':
        pretrained = False
        model = Unet(1, 1).to(device)
        if pretrained:
            model.load_state_dict(torch.load('./weights_4.pth'))
        criterion = torch.nn.BCELoss()
        optimizer = optim.Adam(model.parameters())
        train()
        test()
    

    unet.py不需要变动

    结果

    训练了10个epoch后:累加loss大概到3
    前几张预测图片:


    上为预测,下为groundTruth

    对于100多张的数据集,这个效果还行。
    也算是填了一个以前的坑。


    相关文章

      网友评论

          本文标题:小黑的Python日记:Unet简单实现裂缝分割

          本文链接:https://www.haomeiwen.com/subject/ljazbqtx.html