美文网首页
serving开发 servable类注册和pb消息反射

serving开发 servable类注册和pb消息反射

作者: peteyuan | 来源:发表于2018-10-29 13:39 被阅读76次

    serving最核心的抽象是 servables。如果需要扩展serving,最需要完成的就是自定义servables。如何自定义servable可以参考官方文章,Creating a new kind of servable。但定义好servable后如何使用,官方和网上并没有给出说明。本文先对servable注册机制进行说明。后续会以如何使用官方的hashmap来进行具体介绍。

    简单来说,对于用户实现的每一个servable,都需要注册。看官方的savedmodelsourceadapter,注册代码如下:

    REGISTER_STORAGE_PATH_SOURCE_ADAPTER(SavedModelBundleSourceAdapterCreator,
                                         SavedModelBundleSourceAdapterConfig);
    

    其调用了一个宏 REGISTER_STORAGE_PATH_SOURCE_ADAPTER,这个宏的定义如下:

    #define REGISTER_STORAGE_PATH_SOURCE_ADAPTER(ClassCreator, ConfigProto)      \
      REGISTER_CLASS(StoragePathSourceAdapterRegistry, StoragePathSourceAdapter, \
                     ClassCreator, ConfigProto);
    

    是对另一个更通用的宏 REGISTER_CLASS 的封装。REGISTER_CLASS的定义如下:

    // Registers a factory that creates subclasses of BaseClass by calling
    // ClassCreator::Create().
    #define REGISTER_CLASS(RegistryName, BaseClass, ClassCreator, config_proto, \
                           ...)                                                 \
      REGISTER_CLASS_UNIQ_HELPER(__COUNTER__, RegistryName, BaseClass,          \
                                 ClassCreator, config_proto, ##__VA_ARGS__)
    
    #define REGISTER_CLASS_UNIQ_HELPER(cnt, RegistryName, BaseClass, ClassCreator, \
                                       config_proto, ...)                          \
      REGISTER_CLASS_UNIQ(cnt, RegistryName, BaseClass, ClassCreator,              \
                          config_proto, ##__VA_ARGS__)
    
    #define REGISTER_CLASS_UNIQ(cnt, RegistryName, BaseClass, ClassCreator,    \
                                config_proto, ...)                             \
      static ::tensorflow::serving::internal::ClassRegistry<                   \
          RegistryName, BaseClass, ##__VA_ARGS__>::MapInserter                 \
          register_class_##cnt(                                                \
              (config_proto::default_instance().GetDescriptor()->full_name()), \
              (new ::tensorflow::serving::internal::ClassRegistrationFactory<  \
                  BaseClass, ClassCreator, config_proto, ##__VA_ARGS__>));
    
    

    这段代码什么作用呢?

    说白了,就是定义了一个宏,这个宏的作用是调用某个类的构造函数,在构造函数里完成了真正的注册机制。

    真正的注册代码如下:

      // Nested class whose instantiation inserts a key/value pair into the factory
      // map.
      class MapInserter {
       public:
        MapInserter(const string& config_proto_message_type, FactoryType* factory) {
          InsertIntoMap(config_proto_message_type, factory);
        }
      };
     private:
      // Inserts a key/value pair into the factory map.
      static void InsertIntoMap(const string& config_proto_message_type,
                                FactoryType* factory) {
        LockableFactoryMap* global_map = GlobalFactoryMap();
        {
          mutex_lock lock(global_map->mu);
          global_map
    

    对于每一个调用了REGISTER_CLASS的类,都会被注册,也就是插入到这个全局map中,下一次需要这个类的时候,就从这个map中查找。

    我实现了一个简单的宏定义为类构造函数的测试类。

    #include <stdio.h>
    
    class Test {
    public:
      Test(int x, int cnt) {
        printf("ddddd %d, %d\n", x, cnt);
      }
    };
    
    #define FUNC(x) static Test rr##x(x, __COUNTER__)
    #define DF(x) FUNC(x)
    
    
    //static Test a;
    DF(3);
    DF(4);
    
    int main()
    {
        printf("main...\n");
        return 0;
    }
    

    从上面可以看出,宏经过预处理后,其实就是一个全局静态变量。这个变量的初始化会在main函数执行之前,也就是类的注册会在整个程序执行之前,这一点非常重要。

    如果没有这一点保证,在下面的代码中就会报找不到类的错误。

    // Creates an instance of BaseClass based on a config proto embedded in an Any
      // message.
      //
      // Requires that the config proto in the Any has a compiled-in descriptor.
      static Status CreateFromAny(const google::protobuf::Any& any_config,
                                  AdditionalFactoryArgs... args,
                                  std::unique_ptr<BaseClass>* result) {
        // Copy the config to a proto message of the indicated type.
        string full_type_name;
        Status parse_status =
            ParseUrlForAnyType(any_config.type_url(), &full_type_name);
        if (!parse_status.ok()) {
          return parse_status;
        }
        const protobuf::Descriptor* descriptor =
            protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
                full_type_name);
        if (descriptor == nullptr) {
          return errors::Internal(
              "Unable to find compiled-in proto descriptor of type ",
              full_type_name);
        }
        std::unique_ptr<protobuf::Message> config(
            protobuf::MessageFactory::generated_factory()
                ->GetPrototype(descriptor)
                ->New());
        if (!any_config.UnpackTo(config.get())) {
          return errors::InvalidArgument("Malformed content of Any: ",
                                         any_config.DebugString());
        }
        return Create(*config, std::forward<AdditionalFactoryArgs>(args)...,
                      result);
      }
    

    而这个方法,就是serving中从pb消息反射具体platform类的工具。只有这一步执行成功,程序才能根据用途提供的platform配置文件 实例化出真正的 platform 实例。

    相关文章

      网友评论

          本文标题:serving开发 servable类注册和pb消息反射

          本文链接:https://www.haomeiwen.com/subject/xmdbtqtx.html