美文网首页人工不智能PyTorch/JIT
PyTorch Internal:算子注册

PyTorch Internal:算子注册

作者: A君来了 | 来源:发表于2023-09-29 11:58 被阅读0次

    Overhead

    PyTorch 执行 eager 操作时,例如,torch.add(a, b),调度器(c10::Dispatcher)会根据分派键(DispatchKey) 来查找并执行 add op 的 op kernel (理解PyTorch分发机制的内部工作原理)。因此,算子注册过程就是在调度器中定义 op,并将 kernel function 注册到 op 的指定分派键条目中。

    Torch Library

    torch::Library 是算子注册用的 helper,通过它注册的算子有着相同的命名空间、dispatch key等。

    TORCH_LIBRARY(myops, m) {
      m.def("myadd(Tensor self, Tensor other) -> Tensor");
      m.def("mysub(Tensor self, Tensor other) -> Tensor", mysub_func);
      m.impl("myadd", myadd_func);
    }
    

    m 就是命名空间为 myops 的 library,它通过 m.def 定义了 myadd 和 mysub 这两个 op 的静态信息 schema。mysub 在定义的同时也将 mysub_func 函数注册到 op,而 myadd 的 op kernel 则是通过 m.impl 单独注册的。由于 TORCH_LIBRARY 宏没有指定 dispatch key,因此,这两个 op kernel 都是 CatchAll 函数。

    如果要将 kernel function 注册到指定的 dispatch key,需要用到 TORCH_LIBRARY_IMPL 宏:

    TORCH_LIBRARY_IMPL(myops, CUDA, m) {
      m.impl("myadd", myadd_cuda);
      m.impl("mysub", mysub_cuda);
    }
    

    所有通过 m 注册的 kernel function 都会注册到 op 的 CUDA key 条目中,它执行的优先级会比 CatchAll 更高。

    OperatorDef

    OperatorDef 用于描述调度器中 op 的静态信息,它会提供 registerSchema()registerKernel() 方法给 m.def() 和 m.impl() 分别用于注册 op 和 kernel。

    Kernel list

    通过 m.impl() 注册的 kernel function 会插入到指定 dispatch key 的 kernel list(kernels_)的头部,而调度器则会从列表中的首元素中获取 kernel。也就是说,PyTorch 允许为 op 的同一个 dispatch key 注册多个 kernel,而新 kernel 会覆盖旧 kernel。

    class TORCH_API OperatorEntry final {
      ...
      ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
    };
    
    const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
      auto kern_it = kernels_.find(dispatch_key);
      if (kern_it != kernels_.end()) {
        TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
        TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
        return &kern_it->second.front();
      }
      return nullptr;
    }
    

    End

    相关文章

      网友评论

        本文标题:PyTorch Internal:算子注册

        本文链接:https://www.haomeiwen.com/subject/cmkkbdtx.html