美文网首页人工不智能PyTorch/JIT
理解PyTorch分发机制的内部工作原理

理解PyTorch分发机制的内部工作原理

作者: A君来了 | 来源:发表于2023-06-02 11:13 被阅读0次

    概述

    PyTorch的成功归功于其简单易用性(与Python的用法相似)和动态灵活性。即使在PyTorch 2.0时代,它仍然保持着"Faster, more pythonic and dynamic as ever"的核心特性。

    PyTorch的动态性源自内部的调度器(dispatcher),它可以根据不同的输入类型自动选择正确的运算方式。当调用Python函数时,调度器会根据传入的参数类型选择正确的操作实现,这个过程称为分派(dispatch)。

    例如,当执行矩阵乘法(torch.matmul(a, b))时,调度器会根据输入张量a和b的类型(dtype、shape、device等)选择正确的BLAS库(CPU还是CUDA,float还是half,是否批量计算)来进行计算。对于PyTorch来说,模型的执行过程就是将各个操作(op)分派给本地方法(native function)执行的过程。

    http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

    dispatcher 为每个 op 都维护了一张跳转表(它有点像 C++ 实现多态用的虚表),如上图所示,表中每个条目存储了一个本地方法,有些方法和输入张量所属的设备有关,比如 XLA/CUDA/CPU,有的和 requires_grad 有关,比如 Autograd(这图是从 ezyang’s blog 拿来的,他这篇博客详细讲解了分派机制,建议阅读)。

    当 op 被执行时,e.g. aten::addmm,调度器会在它的跳转表中找出一个方法来执行,而且一个 op 执行过程可能会调用多个方法,例如,输入张量需要求导(requires_grad = true),那会先调用 Autograd 方法来构建反向图,再调用 backend(CPU/CUDA/XLA)的方法来运算。

    分派规则

    http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

    跳转表里的条目是以键值对的形式来存调度方法,其中“键”称为 dispatch key,以 bit 的形式存在,bit 值越大,优先级越高,调度器会从键集(dispatch key set)中选取优先级最高的条目来执行。

    从上图可以看到,键集不只有一个,每个输入张量都有自己的键集,还有 local(local includelocal exclude) 和 global 键集,这些键集最终会合并,调度器从中选取优先级最高的键值对应的方法来执行。

    输入张量的键集是比较好理解的,张量本身具有很多属性,如 layout (dense or sparse)、shape 和 device (CPU or CUDA),一个属性对应一个 dispatch key(可以从 DispatchKey.h 找到所有的 key)。对于不同类型的张量,我们希望能使用不同实现的操作以实现高性能计算的目标。

    Local 键集 与张量个体无关,与模型的行为有关,表示模型运行在某模式中,比如 tracing。它可以允许用户在某个范围内开启或关闭模式。要开启模式就是往 local include 里添加键,要关闭模式就是往 local exclude 里添加要屏蔽的键。

    Global 则表示无论什么操作都会添加的键集(图中 autograd 已经从 global 移到 tensor 键集)。

    分派流程

    http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

    前面也提到,一个 op 的执行是要经历多次分派的,上图就展示了这个过程:

    • 首先,输入张量需要求导(requires_grad = true),调度器就分派给 Autograd key 的本地方法。它会为 op 生成一个反向计算操作,然后,再把控制权交给调度器做重新分派。
    • 接着由于输入张量在CPU上,CPU的方法会被分派执行。

    前面提到,调度器会调用优先级最高的 dispatch key,因此,重新分派的前提是将已经调度过的键从键集里清除,否则重新分派将会重复调用相同的方法。

    Autograd 的本地方法通过在 local exclude 键集中添加要屏蔽的键(Autograd)来避免方法的重复调用。可以通过创建 AutoNonVariableTypeMode RAII guard 来实现:

    class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
     public:
      static Tensor forward(
          AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
        at::AutoNonVariableTypeMode g;
        return myadd(self, other);
      }
      ...
    };
    

    注册自定义操作

    回想一下分派规则:调度器首先找到 op 对应的跳转表,合并键集,并调用键值最大的条目中的函数。由于 dispatch key 是 PyTorch 固定且不可扩展的,因此注册自定义操作需要注册 op 以及跳转表中键的方法。

    注册 op

    TORCH_LIBRARY(myops, m) {
      m.def("myadd(Tensor self, Tensor other) -> Tensor");
    }
    

    PyTorch 提供 TORCH_LIBRARY 用于将 op(也称作 schema stringsignature)注册到一个库里,用户可以在 python 通过 c = torch._ops.myops.myadd(a, b) 调用该 op。

    schema 与 TensorFlow 的 op_def 和 ONNX 的 node 一样,都用于描述一个操作,只是由于 PyTorch 是动态图的,schema 不需要也不能承载更多信息。

    注册 dispatch function

    TORCH_LIBRARY_IMPL(myops, CUDA, m) {
      m.impl("myadd", myadd_cuda);
    }
    

    注册完 op 后,接着就可以通过 TORCH_LIBRARY_IMPL 注册 dispatch key 对应的方法。上述代码片段通过将 myadd_cuda 注册到键:CUDA。

    除了为每个键单独注册一个方法,还可以为所有的键注册一个共同的方法,这类方法称为 catch-all

    TORCH_LIBRARY(myops, m) {
      m.def("myadd", myadd_catchall);
    }
    

    此外,还可以为所有 op 的某个键注册一个共同的 fallback 方法:

    TORCH_LIBRARY_IMPL(_, XLA, m) {
      m.fallback(xla_fallback);
    }
    

    除了 dispatch key 具有优先级外,这些方法也有优先级:impl > catch-all > fallback:

    http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/

    END

    PyTorch的调度器(dispatcher)和分派机制是其灵活性和高性能计算的关键。调度器根据输入类型自动选择适当的操作实现,通过分派流程将操作分派给本地方法执行。分派规则通过 dispatch key 和 keyset 确定执行方法的优先级。注册自定义操作的过程允许用户扩展PyTorch的功能。了解这些原理有助于深入理解PyTorch的内部工作机制,并为模型开发和优化提供指导。

    相关文章

      网友评论

        本文标题:理解PyTorch分发机制的内部工作原理

        本文链接:https://www.haomeiwen.com/subject/sdwkedtx.html