美文网首页
Keras使用问题记录

Keras使用问题记录

作者: 点点渔火 | 来源:发表于2019-10-21 14:27 被阅读0次

1, 模型结构保存:

    def get_config(self):
        config = super(DilatedGatedConv1D, self).get_config()
        config.update(
            {
                'o_dim': self.o_dim,
                'k_size': self.k_size,
                'rate': self.rate,
                'skip_connect': self.skip_connect,
                'drop_gate': self.drop_gate
             }
        )
        return config

如果初始化参数也是一个Layer网络层, Layer对象本身不能序列化, 这就要求重新实现get_config()和from_config()两个方法,实现包含层Layer的序列化反序列化, 参考Bidirectional,Wrapper的实现

    def get_config(self):
        """
        参数的序列化操作
        :return:
        """
        config = super(OurBidirectional, self).get_config()
        config.update(
            {
                'layer': {       # 参照Wrapper 不能直接保留类对象
                    'class_name': self.layer.__class__.__name__,
                    'config': self.layer.get_config()
                }
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """
        自定义从字典config恢复实例参数
        :param config:
        :param custom_objects:
        :return:
        """
        layer = deserialize_layer(config.pop('layer'),
                                  custom_objects=custom_objects)
        return cls(layer, **config)
  • 4, Model结构载入时, 需要用到的自定义层或者第三方类对象传给custom_objects, 否则会提示找不到类对象
def get_custom_objects(self):
        """
        自定义的层或者函数
        :return:
        """
        custom_objects = self.embedding.get_custom_objects()
        custom_objects['OurMasking'] = OurMasking
        custom_objects['CRF'] = CRF
        return custom_objects

keras.models.model_from_json(
            model_json_str,
            custom_objects=model.get_custom_objects()
        )
  • 5, 模型结构存储时,需要包括:
    • model.to_json() :Dict/Str 模型结构参数
    • config:Dict 模型参数
    • class_name: Str 定义的模型类对象 self.__class__name
    • module: Str self.module
      模型载入时, 可以顺序的先动态import module, 然后反射类对象, 接着带着config参数来载入模型, 最后更新每层参数:
    import importlib
    model_module = importlib.import_module((model_info['module']))
    modle_class =  getattr(model_module, model_info['class_name'])
    model.model = keras.models.model_from_json(
            json.dumps(model_info['model']),
            custom_objects=get_custom_objects()
        )
    model.model.load_weights(os.path.join(model_path, 
    'best_model_weight.h5'))
    # 非必须, 只是如果model类本身有Model继承类的参数输入时,需要更新Model参数的训练权重
    for l in model.model.layers:
        print(l.name)
    return model
    

相关文章

网友评论

      本文标题:Keras使用问题记录

      本文链接:https://www.haomeiwen.com/subject/tpdtvctx.html