入口
if __name__ == '__main__':  # script entry point
    # Parse command-line arguments.
    args = parse_args()
    print('Called with args:')
    print(args)

    # Merge an optional config file, then an optional list of command-line
    # overrides, into the default config `cfg` (in that order, so the
    # command-line list is applied last).
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    print('Using config:')
    pprint.pprint(cfg)

    # Seed NumPy's global RNG from the config so that NumPy-driven randomness
    # is repeatable across runs.
    np.random.seed(cfg.RNG_SEED)

    # train set
    imdb, roidb = combined_roidb(args.imdb_name)
    print('{:d} roidb entries'.format(len(roidb)))

    # output directory where the models are saved
    output_dir = get_output_dir(imdb, args.tag)
    print('Output will be saved to `{:s}`'.format(output_dir))

    # tensorboard directory where the summaries are saved during training
    tb_dir = get_output_tb_dir(imdb, args.tag)
    print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

    # Also load the validation set, but with no flipped images. The global
    # flag is cleared only for this load and restored in `finally`, so it is
    # never left disabled if loading the validation set raises.
    orgflip = cfg.TRAIN.USE_FLIPPED
    cfg.TRAIN.USE_FLIPPED = False
    try:
        _, valroidb = combined_roidb(args.imdbval_name)
    finally:
        cfg.TRAIN.USE_FLIPPED = orgflip
    print('{:d} validation roidb entries'.format(len(valroidb)))

    # load network
    # Each builder is a lambda so the network class is only looked up (and
    # instantiated) for the backbone that was actually requested.
    net_builders = {
        'vgg16': lambda: vgg16(),
        'res50': lambda: resnetv1(num_layers=50),
        'res101': lambda: resnetv1(num_layers=101),
        'res152': lambda: resnetv1(num_layers=152),
        'mobile': lambda: mobilenetv1(),
    }
    if args.net not in net_builders:
        raise NotImplementedError(
            'Unsupported network {!r}; expected one of: {}'.format(
                args.net, ', '.join(sorted(net_builders))))
    net = net_builders[args.net]()

    train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=args.weight,
              max_iters=args.max_iters)
combined_roidb(imdb_names)
def combined_roidb(imdb_names):
    """Load the training roidbs for one or more datasets and merge them.

    Args:
        imdb_names: a single dataset name, or several names joined with '+'.

    Returns:
        A ``(imdb, roidb)`` pair. ``roidb`` contains the entries of every
        listed dataset, in the order the names appear; it is the first
        dataset's roidb list, extended in place with the others. For a single
        name, ``imdb`` is a fresh ``get_imdb(imdb_names)``. For several names,
        ``imdb`` is a ``datasets.imdb.imdb`` named after the full
        ``imdb_names`` string, using the classes of the second listed dataset.
    """
    def _load_training_roidb(name):
        # Build the dataset, apply the configured proposal method, and
        # return its training roidb.
        dataset = get_imdb(name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return get_training_roidb(dataset)

    names = imdb_names.split('+')
    per_dataset = [_load_training_roidb(name) for name in names]

    # Concatenate everything onto the first roidb (a no-op for one dataset).
    merged = per_dataset[0]
    for extra in per_dataset[1:]:
        merged.extend(extra)

    if len(per_dataset) == 1:
        return get_imdb(imdb_names), merged

    # NOTE(review): classes come from the second listed dataset; presumably
    # all combined datasets share one class list — TODO confirm.
    template = get_imdb(names[1])
    return datasets.imdb.imdb(imdb_names, template.classes), merged
网友评论