概述
PyTorch的成功归功于其简单易用性(与Python的用法相似)和动态灵活性。即使在PyTorch 2.0时代,它仍然保持着"Faster, more pythonic and dynamic as ever"的核心特性。
PyTorch的动态性源自内部的调度器(dispatcher),它可以根据不同的输入类型自动选择正确的运算方式。当调用Python函数时,调度器会根据传入的参数类型选择正确的操作实现,这个过程称为分派(dispatch)。
例如,当执行矩阵乘法(torch.matmul(a, b))时,调度器会根据输入张量a和b的类型(dtype、shape、device等)选择正确的BLAS库(CPU还是CUDA,float还是half,是否批量计算)来进行计算。对于PyTorch来说,模型的执行过程就是将各个操作(op)分派给本地方法(native function)执行的过程。
http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/dispatcher 为每个 op 都维护了一张跳转表(它有点像 C++ 实现多态用的虚表),如上图所示,表中每个条目存储了一个本地方法,有些方法和输入张量所属的设备有关,比如 XLA/CUDA/CPU
,有的和 requires_grad
有关,比如 Autograd
(这图是从 ezyang’s blog 拿来的,他这篇博客详细讲解了分派机制,建议阅读)。
当 op 被执行时,e.g. aten::addmm
,调度器会在它的跳转表中找出一个方法来执行,而且一个 op 执行过程可能会调用多个方法,例如,输入张量需要求导(requires_grad = true),那会先调用 Autograd 方法来构建反向图,再调用 backend(CPU/CUDA/XLA)的方法来运算。
分派规则
http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/跳转表里的条目是以键值对的形式来存调度方法,其中“键”称为 dispatch key
,以 bit 的形式存在,bit 值越大,优先级越高,调度器会从键集(dispatch key set
)中选取优先级最高的条目来执行。
从上图可以看到,键集不只有一个,每个输入张量都有自己的键集,还有 local(local include
和local exclude
) 和 global 键集,这些键集最终会合并,调度器从中选取优先级最高的键值对应的方法来执行。
输入张量的键集是比较好理解的,张量本身具有很多属性,如 layout (dense or sparse)、shape 和 device (CPU or CUDA),一个属性对应一个 dispatch key(可以从 DispatchKey.h 找到所有的 key)。对于不同类型的张量,我们希望能使用不同实现的操作以实现高性能计算的目标。
Local 键集 与张量个体无关,与模型的行为有关,表示模型运行在某模式中,比如 tracing。它可以允许用户在某个范围内开启或关闭模式。要开启模式就是往 local include 里添加键,要关闭模式就是往 local exclude 里添加要屏蔽的键。
Global 则表示无论什么操作都会添加的键集(图中 autograd 已经从 global 移到 tensor 键集)。
分派流程
http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/前面也提到,一个 op 的执行是要经历多次分派的,上图就展示了这个过程:
- 首先,输入张量需要求导(requires_grad = true),调度器就分派给 Autograd key 的本地方法。它会为 op 生成一个反向计算操作,然后,再把控制权交给调度器做重新分派。
- 接着由于输入张量在CPU上,CPU的方法会被分派执行。
前面提到,调度器会调用优先级最高的 dispatch key,因此,重新分派的前提是将已经调度过的键从键集里清除,否则重新分派将会重复调用相同的方法。
Autograd 的本地方法通过在 local exclude 键集中添加要屏蔽的键(Autograd)来避免方法的重复调用。可以通过创建 AutoNonVariableTypeMode RAII guard 来实现:
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
public:
static Tensor forward(
AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
at::AutoNonVariableTypeMode g;
return myadd(self, other);
}
...
};
注册自定义操作
回想一下分派规则:调度器首先找到 op 对应的跳转表,合并键集,并调用键值最大的条目中的函数。由于 dispatch key 是 PyTorch 固定且不可扩展的,因此注册自定义操作需要注册 op 以及跳转表中键的方法。
注册 op
TORCH_LIBRARY(myops, m) {
m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
PyTorch 提供 TORCH_LIBRARY
用于将 op(也称作 schema string
或 signature
)注册到一个库里,用户可以在 python 通过 c = torch._ops.myops.myadd(a, b)
调用该 op。
schema 与 TensorFlow 的 op_def
和 ONNX 的 node
一样,都用于描述一个操作,只是由于 PyTorch 是动态图的,schema 不需要也不能承载更多信息。
注册 dispatch function
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("myadd", myadd_cuda);
}
注册完 op 后,接着就可以通过 TORCH_LIBRARY_IMPL
注册 dispatch key 对应的方法。上述代码片段通过将 myadd_cuda
注册到键:CUDA。
除了为每个键单独注册一个方法,还可以为所有的键注册一个共同的方法,这类方法称为 catch-all
:
TORCH_LIBRARY(myops, m) {
m.def("myadd", myadd_catchall);
}
此外,还可以为所有 op 的某个键注册一个共同的 fallback
方法:
TORCH_LIBRARY_IMPL(_, XLA, m) {
m.fallback(xla_fallback);
}
除了 dispatch key 具有优先级外,这些方法也有优先级:impl > catch-all > fallback:
http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/END
PyTorch的调度器(dispatcher)和分派机制是其灵活性和高性能计算的关键。调度器根据输入类型自动选择适当的操作实现,通过分派流程将操作分派给本地方法执行。分派规则通过 dispatch key 和 keyset 确定执行方法的优先级。注册自定义操作的过程允许用户扩展PyTorch的功能。了解这些原理有助于深入理解PyTorch的内部工作机制,并为模型开发和优化提供指导。
网友评论