1, 模型结构保存:
-
1, 包含自定义层使用了Lambda层, model.to_json, Pickle会报错
解决方式:- 用Layer的层替换
- 提供一个model.to_pickle() 方法-快速解决方案(dill)
可以参考:
https://github.com/keras-team/keras/issues/2582
-
2,保存模型结构时 跑出 ('Not JSON Serializable:', Dimension(None))
解决方式:
https://blog.csdn.net/Funkdub/article/details/100069905
https://github.com/keras-team/keras/issues/9342 -
3,自定义层的初始化参数,要保存在模型结构中, 需要定义:
def get_config(self):
config = super(DilatedGatedConv1D, self).get_config()
config.update(
{
'o_dim': self.o_dim,
'k_size': self.k_size,
'rate': self.rate,
'skip_connect': self.skip_connect,
'drop_gate': self.drop_gate
}
)
return config
如果初始化参数也是一个Layer网络层, Layer对象本身不能序列化, 这就要求重新实现get_config()和from_config()两个方法,实现包含层Layer的序列化反序列化, 参考Bidirectional,Wrapper的实现
def get_config(self):
"""
参数的序列化操作
:return:
"""
config = super(OurBidirectional, self).get_config()
config.update(
{
'layer': { # 参照Wrapper 不能直接保留类对象
'class_name': self.layer.__class__.__name__,
'config': self.layer.get_config()
}
}
)
return config
@classmethod
def from_config(cls, config, custom_objects=None):
"""
自定义从字典config恢复实例参数
:param config:
:param custom_objects:
:return:
"""
layer = deserialize_layer(config.pop('layer'),
custom_objects=custom_objects)
return cls(layer, **config)
- 4, Model结构载入时, 需要用到的自定义层或者第三方类对象传给custom_objects, 否则会提示找不到类对象
def get_custom_objects(self):
"""
自定义的层或者函数
:return:
"""
custom_objects = self.embedding.get_custom_objects()
custom_objects['OurMasking'] = OurMasking
custom_objects['CRF'] = CRF
return custom_objects
keras.models.model_from_json(
model_json_str,
custom_objects=model.get_custom_objects()
)
- 5, 模型结构存储时,需要包括:
- model.to_json() :Dict/Str 模型结构参数
- config:Dict 模型参数
- class_name: Str 定义的模型类对象 self.__class__name
- module: Str self.module
模型载入时, 可以顺序的先动态import module, 然后反射类对象, 接着带着config参数来载入模型, 最后更新每层参数:
import importlib model_module = importlib.import_module((model_info['module'])) modle_class = getattr(model_module, model_info['class_name']) model.model = keras.models.model_from_json( json.dumps(model_info['model']), custom_objects=get_custom_objects() ) model.model.load_weights(os.path.join(model_path, 'best_model_weight.h5')) # 非必须, 只是如果model类本身有Model继承类的参数输入时,需要更新Model参数的训练权重 for l in model.model.layers: print(l.name) return model
网友评论