from tensorflow.contrib.slim.python.slim.nets.resnet_v1 import (resnet_v1_50, resnet_arg_scope)
from tensorflow.contrib.slim.python.slim.nets.vgg import (vgg_16, vgg_arg_scope)
# Training branch: wrap the training batch in a TensorLayer input layer, then
# build TF-Slim's ResNet-50 on top of it inside the ResNet arg scope.
tra_patch = tl.layers.InputLayer(tra_patch, name='tra_Inputlayer')
with slim.arg_scope(resnet_arg_scope()):
    tra_network = tl.layers.SlimNetsLayer(
        layer=tra_patch,
        slim_layer=resnet_v1_50,
        # Keyword arguments forwarded to resnet_v1_50.
        slim_args={
            'num_classes': num_classes,
            'is_training': True,
            'global_pool': True,
        },
        name='resnet_v1_50')
# Validation branch: the same ResNet-50 built in inference mode
# (is_training=False) under the same layer name, with 'reuse': True so that
# the existing variables are reused instead of new ones being created.
val_patch = tl.layers.InputLayer(val_patch, name='val_Inputlayer')
# Name reuse is switched on as its own statement; it is not an argument of
# resnet_arg_scope(), which is called with its defaults exactly as for the
# training branch.
# NOTE(review): set_name_reuse exists in TensorLayer 1.x only - confirm the
# installed version still provides it.
tl.layers.set_name_reuse(True)
with slim.arg_scope(resnet_arg_scope()):
    val_network = tl.layers.SlimNetsLayer(
        layer=val_patch,
        slim_layer=resnet_v1_50,
        # Keyword arguments forwarded to resnet_v1_50.
        slim_args={
            'num_classes': num_classes,
            'is_training': False,
            'global_pool': True,
            'reuse': True
        },
        name='resnet_v1_50')
# Discard the size-1 spatial dimensions of the network outputs.
# Only axes 1 and 2 are squeezed so that a batch of size 1 keeps its batch
# dimension (a bare tf.squeeze would remove that as well).
# Assumes the outputs are [batch, 1, 1, num_classes], as expected with
# global_pool=True - TODO confirm against the slim version in use.
tra_logits = tra_network.outputs
tra_logits = tf.squeeze(tra_logits, axis=[1, 2])
val_logits = val_network.outputs
val_logits = tf.squeeze(val_logits, axis=[1, 2])
# [Scraped page heading: "Reader comments" - not part of the script.]