Code link: https://github.com/LeeJunHyun/Image_Segmentation
main.py
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=224)
    parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net')

    # training hyper-parameters
    parser.add_argument('--img_ch', type=int, default=3)
    parser.add_argument('--output_ch', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--num_epochs_decay', type=int, default=70)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)    # momentum1 in Adam
    parser.add_argument('--beta2', type=float, default=0.999)  # momentum2 in Adam
    parser.add_argument('--augmentation_prob', type=float, default=0.4)

    parser.add_argument('--log_step', type=int, default=2)
    parser.add_argument('--val_step', type=int, default=2)

    # misc
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net')
    parser.add_argument('--model_path', type=str, default='./models')
    parser.add_argument('--train_path', type=str, default='./dataset/train/')
    parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
    parser.add_argument('--test_path', type=str, default='./dataset/test/')
    parser.add_argument('--result_path', type=str, default='./result/')
    parser.add_argument('--cuda_idx', type=int, default=1)

    config = parser.parse_args()
    main(config)
argparse is a Python module: a parser for command-line options, arguments, and sub-commands. The argparse module makes it easy to write user-friendly command-line interfaces. The program defines the arguments it requires, and argparse figures out how to parse them out of sys.argv. The argparse module also automatically generates help and usage messages, and issues errors when the user passes invalid arguments to the program.
- Create the parser
parser = argparse.ArgumentParser(description='Process some integers.')
The first step in using argparse is to create an ArgumentParser object. The ArgumentParser object holds all the information needed to parse the command line into Python data types.
- Add arguments
parser.add_argument('integers', metavar='N', type=int, nargs='+', help='an integer for the accumulator')
- Parse arguments
ArgumentParser parses the arguments through the parse_args() method.
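Putting the three steps together, here is a small self-contained sketch (the flag names are illustrative only, not the repository's full argument list):

import argparse

# 1. create the parser
parser = argparse.ArgumentParser(description='Toy example of the three steps above')
# 2. add arguments (names and defaults here are made up for illustration)
parser.add_argument('--image_size', type=int, default=224, help='input image size')
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
# 3. parse the command line, e.g. python demo.py --image_size 256 --mode test
config = parser.parse_args()
print(config.image_size, config.mode)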
def main
def main(config):
    cudnn.benchmark = True
    if config.model_type not in ['U_Net', 'R2U_Net', 'AttU_Net', 'R2AttU_Net']:
        print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
        print('Your input for model_type was %s' % config.model_type)
        return

    # Create directories if not exist
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.result_path):
        os.makedirs(config.result_path)
    config.result_path = os.path.join(config.result_path, config.model_type)
    if not os.path.exists(config.result_path):
        os.makedirs(config.result_path)

    lr = random.random()*0.0005 + 0.0000005
    augmentation_prob = random.random()*0.7
    epoch = random.choice([100, 150, 200, 250])
    decay_ratio = random.random()*0.8
    decay_epoch = int(epoch*decay_ratio)

    config.augmentation_prob = augmentation_prob
    config.num_epochs = epoch
    config.lr = lr
    config.num_epochs_decay = decay_epoch

    print(config)

    train_loader = get_loader(image_path=config.train_path,
                              image_size=config.image_size,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              mode='train',
                              augmentation_prob=config.augmentation_prob)
    valid_loader = get_loader(image_path=config.valid_path,
                              image_size=config.image_size,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              mode='valid',
                              augmentation_prob=0.)
    test_loader = get_loader(image_path=config.test_path,
                             image_size=config.image_size,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers,
                             mode='test',
                             augmentation_prob=0.)

    solver = Solver(config, train_loader, valid_loader, test_loader)

    # Train and sample the images
    if config.mode == 'train':
        solver.train()
    elif config.mode == 'test':
        solver.test()
- random.random() generates a random floating-point number n with 0 <= n < 1.0.
- Learning rate decay: if the learning rate stays too large, the parameters oscillate back and forth around the optimum as training approaches convergence, so the learning rate is gradually reduced over the training epochs (for example exponentially), shrinking the gradient-descent step size.
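The lines above randomly sample lr, num_epochs, augmentation_prob, etc., which amounts to a simple random hyper-parameter search. As a minimal sketch of the decay idea itself, using PyTorch's built-in exponential scheduler (this is not how the repository's Solver implements decay; the Linear model and gamma value are placeholders):

import torch
import torch.optim as optim

model = torch.nn.Linear(10, 1)                  # stand-in for the segmentation network
optimizer = optim.Adam(model.parameters(), lr=2e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

for epoch in range(100):
    # ... one training epoch would run here ...
    scheduler.step()                            # lr <- lr * gamma after every epoch
    # print(scheduler.get_last_lr())            # inspect the current learning rate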
data_loader.py
get_loader
def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train', augmentation_prob=0.4):
    """Builds and returns Dataloader."""
    dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode, augmentation_prob=augmentation_prob)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader
- DataLoader: used when training the model to split the training data into mini-batches; each iteration yields one batch until all the data has been served. It essentially wraps the dataset for batched, shuffled iteration.
- num_workers: the number of worker subprocesses used to load the data (PyTorch's default is 0). With 0, the data is loaded in the main process. The value must be >= 0.
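To illustrate the batching behavior, here is a minimal sketch that uses a stand-in TensorDataset instead of the ImageFolder class shown below (shapes and batch size are arbitrary):

import torch
from torch.utils import data

images = torch.rand(8, 3, 224, 224)   # 8 fake RGB images
masks = torch.rand(8, 1, 224, 224)    # 8 fake ground-truth masks
dataset = data.TensorDataset(images, masks)

loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
for image, GT in loader:
    # each iteration yields one mini-batch of 2 samples
    print(image.shape, GT.shape)   # torch.Size([2, 3, 224, 224]) torch.Size([2, 1, 224, 224])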
ImageFolder
class ImageFolder(data.Dataset):
    def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
        """Initializes image paths and preprocessing module."""
        self.root = root
        # GT : Ground Truth
        self.GT_paths = root[:-1]+'_GT/'
        self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        self.augmentation_prob = augmentation_prob
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    def __getitem__(self, index):
        """Reads an image from a file and preprocesses it and returns."""
        image_path = self.image_paths[index]
        filename = image_path.split('_')[-1][:-len(".jpg")]
        GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'

        image = Image.open(image_path)
        GT = Image.open(GT_path)

        aspect_ratio = image.size[1]/image.size[0]

        Transform = []

        ResizeRange = random.randint(300, 320)
        Transform.append(T.Resize((int(ResizeRange*aspect_ratio), ResizeRange)))
        p_transform = random.random()

        if (self.mode == 'train') and p_transform <= self.augmentation_prob:
            RotationDegree = random.randint(0, 3)
            RotationDegree = self.RotationDegree[RotationDegree]
            if (RotationDegree == 90) or (RotationDegree == 270):
                aspect_ratio = 1/aspect_ratio

            Transform.append(T.RandomRotation((RotationDegree, RotationDegree)))

            RotationRange = random.randint(-10, 10)
            Transform.append(T.RandomRotation((RotationRange, RotationRange)))
            CropRange = random.randint(250, 270)
            Transform.append(T.CenterCrop((int(CropRange*aspect_ratio), CropRange)))
            Transform = T.Compose(Transform)

            image = Transform(image)
            GT = Transform(GT)

            ShiftRange_left = random.randint(0, 20)
            ShiftRange_upper = random.randint(0, 20)
            ShiftRange_right = image.size[0] - random.randint(0, 20)
            ShiftRange_lower = image.size[1] - random.randint(0, 20)
            image = image.crop(box=(ShiftRange_left, ShiftRange_upper, ShiftRange_right, ShiftRange_lower))
            GT = GT.crop(box=(ShiftRange_left, ShiftRange_upper, ShiftRange_right, ShiftRange_lower))

            if random.random() < 0.5:
                image = F.hflip(image)
                GT = F.hflip(GT)

            if random.random() < 0.5:
                image = F.vflip(image)
                GT = F.vflip(GT)

            Transform = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.02)

            image = Transform(image)

            Transform = []

        Transform.append(T.Resize((int(256*aspect_ratio)-int(256*aspect_ratio)%16, 256)))
        Transform.append(T.ToTensor())
        Transform = T.Compose(Transform)

        image = Transform(image)
        GT = Transform(GT)

        Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        image = Norm_(image)

        return image, GT
    def __len__(self):
        """Returns the total number of images."""
        return len(self.image_paths)
- Loading the dataset
torch.utils.data.Dataset
__len__: the total number of samples in the dataset
__getitem__: returns the sample at the given index
def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
root: directory where the images are stored
image_size: the resize target
mode: whether this split is train, valid, or test
augmentation_prob: the probability of applying data augmentation
- def __getitem__(self, index): reads one image, preprocesses it, and returns it (see the sketch after this list)
Preprocessing references:
https://blog.csdn.net/Haiqiang1995/article/details/90416321
https://blog.csdn.net/lwplwf/article/details/85776309
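For reference, a minimal custom Dataset sketch following the same __len__/__getitem__ contract (the folder layout and class name are made up and simpler than the ISIC naming used above):

import os
from PIL import Image
from torch.utils import data
import torchvision.transforms as T

class PairDataset(data.Dataset):
    """Toy dataset: each image in `root` has a mask with the same filename in `root_GT`."""
    def __init__(self, root, image_size=224):
        self.root = root
        self.image_paths = sorted(os.path.join(root, f) for f in os.listdir(root))
        self.transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])

    def __len__(self):
        # total number of samples
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        mask_path = os.path.join(self.root + '_GT', os.path.basename(image_path))
        image = self.transform(Image.open(image_path).convert('RGB'))
        mask = self.transform(Image.open(mask_path).convert('L'))
        return image, mask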
solver.py
Solver
- def build_model(self):
Instantiates the selected network and its optimizer, e.g. for U_Net:
    if self.model_type == 'U_Net':
        self.unet = U_Net(img_ch=3, output_ch=1)
    self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr, [self.beta1, self.beta2])
    self.unet.to(self.device)
U_Net is the network model, imported with from network import U_Net.
Solver.train
- self.unet.load_state_dict(torch.load(unet_path))
Loads a previously saved model (see the save/load sketch after this list).
On saving and loading models: https://www.jianshu.com/p/60fc57e19615
- self.unet.train(True)
self.unet.train(False)
self.unet.eval()
For models containing layers such as BatchNorm and Dropout, the forward pass behaves differently in training and in evaluation, so before running the model you must specify whether it is currently training or evaluating.
model.train()  # training behavior: BatchNorm uses batch statistics, Dropout is active
model.eval()   # evaluation behavior: BatchNorm uses running statistics, Dropout is disabled
- train(mode=True)
Parameters: mode (bool) - True puts the module in training mode, False in evaluation mode. Default: True.
- eval()
Equivalent to self.train(False).
- The evaluation metrics such as get_accuracy() come from evaluation.py (Python implementations of the metrics; see the sketch below).
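On the saving/loading point above, a minimal hedged sketch of the usual state_dict round trip (the file name is just an example, not the path the repository uses):

import torch
from network import U_Net   # the repository's U-Net definition

unet = U_Net(img_ch=3, output_ch=1)

# save only the parameters; this is the form that load_state_dict(torch.load(...)) expects
torch.save(unet.state_dict(), 'U_Net-example.pkl')

# later: rebuild the architecture, then load the weights back
unet2 = U_Net(img_ch=3, output_ch=1)
unet2.load_state_dict(torch.load('U_Net-example.pkl'))
unet2.eval()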
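And as a rough illustration of how train()/eval() and such pixel-wise metrics fit together for a binary segmentation model (a hedged sketch, not the exact code in evaluation.py or Solver; the 0.5 thresholds and the sigmoid are assumptions):

import torch

def get_accuracy(SR, GT, threshold=0.5):
    # SR: predicted probabilities, GT: ground-truth mask, both of shape (N, 1, H, W)
    SR = SR > threshold
    GT = GT > 0.5
    return ((SR == GT).float().sum() / GT.numel()).item()

def validate(model, valid_loader):
    """Run the model over the validation loader in eval mode and return the mean accuracy."""
    model.eval()                      # BatchNorm/Dropout switch to evaluation behavior
    accs = []
    with torch.no_grad():             # no gradients needed during validation
        for image, GT in valid_loader:
            SR = torch.sigmoid(model(image))
            accs.append(get_accuracy(SR, GT))
    model.train()                     # restore training mode afterwards
    return sum(accs) / len(accs)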
Solver.test
dataset.py
Used to split the raw data into train/valid/test sets (see the sketch below).
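A rough sketch of such a split (the ratios, folder names, and copy strategy are assumptions for illustration, not the exact logic of dataset.py):

import os
import random
import shutil

def split_dataset(origin_dir, out_root, train_ratio=0.6, valid_ratio=0.2):
    """Randomly copy images from origin_dir into train/valid/test subfolders of out_root."""
    filenames = sorted(os.listdir(origin_dir))
    random.shuffle(filenames)
    n_train = int(len(filenames) * train_ratio)
    n_valid = int(len(filenames) * valid_ratio)
    splits = {
        'train': filenames[:n_train],
        'valid': filenames[n_train:n_train + n_valid],
        'test': filenames[n_train + n_valid:],
    }
    for split, names in splits.items():
        split_dir = os.path.join(out_root, split)
        os.makedirs(split_dir, exist_ok=True)
        for name in names:
            shutil.copyfile(os.path.join(origin_dir, name), os.path.join(split_dir, name))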
References:
How to use argparse.ArgumentParser()
Learning rate decay
model.train() and model.eval() in PyTorch