两种方法:
直接读取模型时,以下两种方式都是不可行的:
错误方法一:
# Wrong method 1: manually filter out variables whose op name starts with
# any of the excluded scope prefixes, then restore the remainder.
checkpoint_exclude_scopes = "voxel_decoder,detail"
if checkpoint_exclude_scopes:
    # Comma-separated scope list -> cleaned prefix list.
    exclusions = [scope.strip()
                  for scope in checkpoint_exclude_scopes.split(',')]
    # Keep every model variable that does NOT fall under an excluded scope.
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not any(var.op.name.startswith(excl) for excl in exclusions)
    ]
save = tf.train.Saver(variables_to_restore)
save.restore(sess, save_path)
错误方法二:
# Wrong method 2: manually collect variables whose op name starts with one
# of the included scope prefixes, then restore only those.
checkpoint_include_scopes = "gaussian_MLP_encoder,bernoulli_MLP_decoder"
if checkpoint_include_scopes:
    # Comma-separated scope list -> cleaned prefix list.
    inclusions = [scope.strip()
                  for scope in checkpoint_include_scopes.split(',')]
    variables_to_restore = []
    for var in slim.get_model_variables():
        for inclusion in inclusions:
            if var.op.name.startswith(inclusion):
                variables_to_restore.append(var)
                # Stop after the first matching prefix; without this break a
                # variable matching several prefixes would be appended more
                # than once, passing duplicates to tf.train.Saver.
                break
save = tf.train.Saver(variables_to_restore)
save.restore(sess, save_path)
正确方案1:
# Correct solution 1: let slim select the variables under the listed scopes
# and build an init function from the partial checkpoint.
scopes_to_load = ['is_training', 'encoder', 'decoder']
inception_except_logits = slim.get_variables_to_restore(include=scopes_to_load)
init_fn = slim.assign_from_checkpoint_fn(
    part_save_path,
    inception_except_logits,
    ignore_missing_vars=True,  # tolerate variables absent from the checkpoint
)
init_fn(sess)
正确方案2:
# Correct solution 2: same idea, but list the scopes to SKIP and let slim
# restore everything else from the partial checkpoint.
scopes_to_skip = ['is_training', 'encoder', 'decoder']
inception_except_logits = slim.get_variables_to_restore(exclude=scopes_to_skip)
init_fn = slim.assign_from_checkpoint_fn(
    part_save_path,
    inception_except_logits,
    ignore_missing_vars=True,  # tolerate variables absent from the checkpoint
)
init_fn(sess)
网友评论