美文网首页
Keras使用问题记录

Keras使用问题记录

作者: 点点渔火 | 来源:发表于2019-10-21 14:27 被阅读0次

    1, 模型结构保存:

        def get_config(self):
            config = super(DilatedGatedConv1D, self).get_config()
            config.update(
                {
                    'o_dim': self.o_dim,
                    'k_size': self.k_size,
                    'rate': self.rate,
                    'skip_connect': self.skip_connect,
                    'drop_gate': self.drop_gate
                 }
            )
            return config
    

    如果初始化参数也是一个Layer网络层, Layer对象本身不能序列化, 这就要求重新实现get_config()和from_config()两个方法,实现包含层Layer的序列化反序列化, 参考Bidirectional,Wrapper的实现

        def get_config(self):
            """
            参数的序列化操作
            :return:
            """
            config = super(OurBidirectional, self).get_config()
            config.update(
                {
                    'layer': {       # 参照Wrapper 不能直接保留类对象
                        'class_name': self.layer.__class__.__name__,
                        'config': self.layer.get_config()
                    }
                }
            )
            return config
    
        @classmethod
        def from_config(cls, config, custom_objects=None):
            """
            自定义从字典config恢复实例参数
            :param config:
            :param custom_objects:
            :return:
            """
            layer = deserialize_layer(config.pop('layer'),
                                      custom_objects=custom_objects)
            return cls(layer, **config)
    
    • 4, Model结构载入时, 需要用到的自定义层或者第三方类对象传给custom_objects, 否则会提示找不到类对象
    def get_custom_objects(self):
            """
            自定义的层或者函数
            :return:
            """
            custom_objects = self.embedding.get_custom_objects()
            custom_objects['OurMasking'] = OurMasking
            custom_objects['CRF'] = CRF
            return custom_objects
    
    keras.models.model_from_json(
                model_json_str,
                custom_objects=model.get_custom_objects()
            )
    
    • 5, 模型结构存储时,需要包括:
      • model.to_json() :Dict/Str 模型结构参数
      • config:Dict 模型参数
      • class_name: Str 定义的模型类对象 self.__class__name
      • module: Str self.module
        模型载入时, 可以顺序的先动态import module, 然后反射类对象, 接着带着config参数来载入模型, 最后更新每层参数:
      import importlib
      model_module = importlib.import_module((model_info['module']))
      modle_class =  getattr(model_module, model_info['class_name'])
      model.model = keras.models.model_from_json(
              json.dumps(model_info['model']),
              custom_objects=get_custom_objects()
          )
      model.model.load_weights(os.path.join(model_path, 
      'best_model_weight.h5'))
      # 非必须, 只是如果model类本身有Model继承类的参数输入时,需要更新Model参数的训练权重
      for l in model.model.layers:
          print(l.name)
      return model
      

    相关文章

      网友评论

          本文标题:Keras使用问题记录

          本文链接:https://www.haomeiwen.com/subject/tpdtvctx.html