美文网首页
看PyTorch源代码的心路历程

看PyTorch源代码的心路历程

作者: SunnyZhou1024 | 来源:发表于2021-01-08 12:54 被阅读0次

    1. 起因

    曾经碰到过别人的模型prelu在内部的推理引擎算出的结果与其在原始框架PyTorch中不一致的情况,虽然理论上大家实现的都是一个算法,但是从参数上看,因为经过了模型转换,中间做了一些调整。为了确定究竟是初始参数传递就出了问题还是在后续传递过程中继续做了更改、亦或者是最终算法实现方面有着细微差别导致最终输出不同,就想着去看一看PyTorch一路下来是怎么做的。
    但是代码跟着跟着就跟丢了,才会发现,PyTorch真的是一个很复杂的项目,但就像舌尖里面说的,环境越是恶劣,回报越是丰厚。为了以后再想跟踪的时候方便,因此决定以PReLU为例静态梳理一下PyTorch的代码结构。捣鼓的这些天,对如何构建一个带有C/C++代码的Python又有了新的了解,这也算是意外的收获吧。

    2. 历程

    首先,我们从PReLU的导入路径torch.nn.PReLU中知道,他应在径进torch\nn\之下,进入该路径虽然没看到,但是我们在该路径下的__init__.py中知道,其实它就在torch\nn\modules\activation.py中。类PReLU最终调用了从torch\nn\functional.py导入的prelu方法。顺腾摸瓜,找到prelu,它长下面这样:

    def prelu(input, weight):
        # type: (Tensor, Tensor) -> Tensor
        if not torch.jit.is_scripting(): 
            if type(input) is not Tensor and has_torch_function((input,)):
                return handle_torch_function(prelu, (input,), input, weight)
        return torch.prelu(input, weight)
    

    经过人脑对代码的一番执行你会发现,第一个if条件满足,而第二个if不满足。因此,最终想看算法,得去看torch.prelu()。好吧,接着干……

    一番搜寻之后你会发现,Python代码中在torch这个包下面你是找不到prelu的定义的。但是绝望之际我们在torch包的__init__.py之中看到看下面几行代码:

    # pytorch\torch\__init__.py
    
    # 为了简洁,省去不必要代码,详细代码参见pytorch\torch\__init__.py
    try:
        # _initExtension is chosen (arbitrarily) as a sentinel.
        from torch._C import _initExtension
    
    
    __all__ += [name for name in dir(_C)
                if name[0] != '_' and
                not name.endswith('Base')]
    
    if TYPE_CHECKING:
        # Some type signatures pulled in from _VariableFunctions here clash with
        # signatures already imported. For now these clashes are ignored; see
        # PR #43339 for details.
        from torch._C._VariableFunctions import *  # type: ignore
    
    for name in dir(_C._VariableFunctions):
        if name.startswith('__'):
            continue
        globals()[name] = getattr(_C._VariableFunctions, name)
        __all__.append(name)
    

    这是全村最后的希望了。我们知道__all__中的名字其实就是该模块有意暴露出去的API。
    什么意思呢?也就是说虽然我们明文上已经看不到了prelu的定义,但是这几行代码表明有一大堆身份不明的API被暗搓搓的导入了,这其中就很有可能存在我们朝思暮想的prelu

    那么我们怎么凭借这么一点微弱的线索确定我们的猜测到底对不对呢?这里我们就用到了Python的一个关键知识:C/C++扩展。(戳这里《使用C语言编写Python模块-引子》《Python调用C++之PYBIND11简介》了解更多)

    我们知道Python C/C++扩展有着固定的格式,只要我们找到模块初始化入口,就能顺藤摸瓜找到该模块暴露的给Python解释器所有函数。Python 3中的初始化函数样子为PyInit_<module_name>,其中<module_name>就是模块的名字。例如在前面提到的from torch._C import *中,模块torch下面必要有一个名字为_C的子模块。因此它的初始化函数应该为PyInit__C,我们搜索该名字就能找到模块入口。当然另外还有一种方法,就是查看setup.py文件中关于扩展的描述信息:

    // pytorch\setup.py
    main_sources = ["torch/csrc/stub.c"]
    C = Extension("torch._C",
                      libraries=main_libraries,
                      sources=main_sources,
                      language='c',
                      extra_compile_args=main_compile_args + extra_compile_args,
                      include_dirs=[],
                      library_dirs=library_dirs,
                      extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
        extensions.append(C)
    

    不管是通过搜索还是查看setup.py,我们最终都成功定位到了位于pytorch\torch\csrc\stub.c下的模块初始化函数PyInit__C(void),并进一步跟踪其调用的函数initModule(),便可以知道具体都暴露了哪些API给Python解释器。

    // pytorch\torch\csrc\stub.c
    PyMODINIT_FUNC PyInit__C(void)
    {
      return initModule();
    }
    
    
    // pytorch\torch\csrc\Module.cpp
    initModule()
    

    进入initModule()寻找一番,你会发现,模块_C中依然没有prelu的Python接口。怎么办?莫慌,通过前面对torch.__init__.py的分析,我们知道我们还有希望——_C模块下的子模块_VariableFunctions,这真的是最后的希望了!没了别的路可以走了,只能是硬着头皮找。经过一番惊天地泣鬼神、艰苦卓绝的寻找,我们在initModule()的调用链initModule()->THPVariable_initModule(module)->torch::autograd::initTorchFunctions(module)中发现了_VariableFunctions的踪影。Aha,simple!

    void initTorchFunctions(PyObject* module) {
      if (PyType_Ready(&THPVariableFunctions) < 0) {
        throw python_error();
      }
      Py_INCREF(&THPVariableFunctions);
    
      // Steals
      Py_INCREF(&THPVariableFunctions);
      if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
        throw python_error();
      }
      // PyType_GenericNew returns a new reference
      THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
      // PyModule_AddObject steals a reference
      if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
        throw python_error();
      }
    }
    

    但是!!别高兴太早!查看模块_VariableFunctions中暴露的接口你会发现,根本就没有我们想要的!如下面的代码所示:

    static PyMethodDef torch_functions[] = {
      {"arange", castPyCFunctionWithKeywords(THPVariable_arange),
        METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor),
        METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
      {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
      ${py_method_defs}
      {NULL}
    };
    

    上面的代码中我们找不到prelu的任何身影。会不会prelu可以绕开C/C++扩展的方式直接被Python使用呢?所以不会出现在这里?答案是不会,自古华山一条路,程序是不会跟你讲潜规则的。那么既然最终代码已经跟丢了,作者一定是使用了黑魔法,作为麻瓜的我无计可施,本文也该结束了……

    等等,上面的C代码中好像混入了奇怪的东西——${py_method_defs}。这种语法好像C/C++语法里面是没有的,反而是Shell这类脚本里面才会有,难道是新特性?费劲查找了一圈,并没有发现C/C++中有这种语法,既然不是正经语法,那么混入C/C++中肯定会导致编译失败,但是它确实就在那里。那么真相只有一个:它就是个占位符,后面肯定会有真正的代码替换它!

    接下来怎么办?搜索!使用py_method_defs作为关键字全局搜索,最终我们会发现,确实是有一个Python脚本对这个占位符进行了替换,而替换的结果就是我们一直寻找的prelu终于出现在了模块_VariableFunctions之中。好,破案了。

    但是就像警察破案,即便有单个证据,也要找到其他证据形成完整证据链才能使得证据具有说服力。虽然我们通过搜索得知了prelu会出现在模块_VariableFunctions中,但是它究竟怎么来的目前还是很模糊:占位符在什么时候被谁调用的脚本进行了替换?

    实际上,这一切都是有迹可循的。踪迹依旧在setup.py中。进入setup.py的主函数,在调用setup函数之前会看到一个名为build_deps()的函数调用,此函数最终会调用指定平台的CMake去按照根目录下CMakeLists.txt中的脚本进行构建。根目录下的CMakeLists.txt最终又会调用到caffe2目录下的CMakeLists.txt(add_subdirectory(caffe2)),而caffe2/CMakeLists.txt中就会调用到进行代码生成的Python脚本,如下所示:

    代码生成脚本起调过程示意图
    // pytorch\caffe2\CMakeLists.txt
      add_custom_command( OUTPUT
        ${TORCH_GENERATED_CODE}
        COMMAND
        "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
          --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
          --native-functions-path "aten/src/ATen/native/native_functions.yaml"
          --nn-path "aten/src"
          $<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
          $<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
          --force_schema_registration
    

    进行代码生成的主要流程如下面代码块所示,其大概流程是main()先解析传递给脚本的参数,之后将参数传递给generate_code()。结合caffe2/CMakeLists.txt中脚本调用时传递的参数可知,generate_code()中的是三个gen_*()函数都得到了调用,而在gen_autograd_python()会调用到一个名为create_python_bindings()的函数,这个函数就是真正执行代码生成的地方。

    代码生成器调用流程示意图
    // tools/setup_helpers/generate_code.py
    def generate_code(ninja_global=None,
                      declarations_path=None,
                      nn_path=None,
                      native_functions_path=None,
                      install_dir=None,
                      subset=None,
                      disable_autograd=False,
                      force_schema_registration=False,
                      operator_selector=None):
    
        if subset == "pybindings" or not subset:
            gen_autograd_python(
                declarations_path or DECLARATIONS_PATH,
                native_functions_path or NATIVE_FUNCTIONS_PATH,
                autograd_gen_dir,
                autograd_dir)
    
        if operator_selector is None:
            operator_selector = SelectiveBuilder.get_nop_selector()
    
        if subset == "libtorch" or not subset:
    
            gen_autograd(
                declarations_path or DECLARATIONS_PATH,
                native_functions_path or NATIVE_FUNCTIONS_PATH,
                autograd_gen_dir,
                autograd_dir,
                disable_autograd=disable_autograd,
                operator_selector=operator_selector,
            )
    
        if subset == "python" or not subset:
            gen_annotated(
                native_functions_path or NATIVE_FUNCTIONS_PATH,
                python_install_dir,
                autograd_dir)
    
    def main():
        parser = argparse.ArgumentParser(description='Autogenerate code')
        parser.add_argument('--declarations-path')
        parser.add_argument('--native-functions-path')
        parser.add_argument('--nn-path')
        parser.add_argument('--ninja-global')
        parser.add_argument('--install_dir')
        parser.add_argument(
            '--subset',
            help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
        )
        parser.add_argument(
            '--disable-autograd',
            default=False,
            action='store_true',
            help='It can skip generating autograd related code when the flag is set',
        )
        parser.add_argument(
            '--selected-op-list-path',
            help='Path to the YAML file that contains the list of operators to include for custom build.',
        )
        parser.add_argument(
            '--operators_yaml_path',
            help='Path to the model YAML file that contains the list of operators to include for custom build.',
        )
        parser.add_argument(
            '--force_schema_registration',
            action='store_true',
            help='force it to generate schema-only registrations for ops that are not'
            'listed on --selected-op-list'
        )
        options = parser.parse_args()
    
        generate_code(
            options.ninja_global,
            options.declarations_path,
            options.nn_path,
            options.native_functions_path,
            options.install_dir,
            options.subset,
            options.disable_autograd,
            options.force_schema_registration,
            # options.selected_op_list
            operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
        )
    
    if __name__ == "__main__":
        main()
    
    // pytorch\tools\autograd\gen_autograd.py
    def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
        from .load_derivatives import load_derivatives
        differentiability_infos = load_derivatives(
            os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
    
        template_path = os.path.join(autograd_dir, 'templates')
    
        # Generate Functions.h/cpp
        from .gen_autograd_functions import gen_autograd_functions_python
        gen_autograd_functions_python(
            out, differentiability_infos, template_path)
    
        # Generate Python bindings
        from . import gen_python_functions
        deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
        gen_python_functions.gen(
            out, native_functions_path, deprecated_path, template_path)
    
    // pytorch\tools\autograd\gen_python_functions.py
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    #
    #                            Main Function
    #
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    
    def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
        fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    
        methods = load_signatures(native_yaml_path, deprecated_yaml_path, method=True)
        create_python_bindings(
            fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)
    
        functions = load_signatures(native_yaml_path, deprecated_yaml_path, method=False)
        create_python_bindings(
            fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
    
        create_python_bindings(
            fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)
    
        create_python_bindings(
            fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)
    
        create_python_bindings(
            fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
    
    def create_python_bindings(
        fm: FileManager,
        pairs: Sequence[PythonSignatureNativeFunctionPair],
        pred: Callable[[NativeFunction], bool],
        module: Optional[str],
        filename: str,
        *,
        method: bool,
    ) -> None:
        """Generates Python bindings to ATen functions"""
        py_methods: List[str] = []
        py_method_defs: List[str] = []
        py_forwards: List[str] = []
    
        grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
        for pair in pairs:
            if pred(pair.function):
                grouped[pair.function.func.name.name].append(pair)
    
        for name in sorted(grouped.keys(), key=lambda x: str(x)):
            overloads = grouped[name]
            py_methods.append(method_impl(name, module, overloads, method=method))
            py_method_defs.append(method_def(name, module, overloads, method=method))
            py_forwards.extend(forward_decls(name, overloads, method=method))
    
        fm.write_with_template(filename, filename, lambda: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
            'py_forwards': py_forwards,
            'py_methods': py_methods,
            'py_method_defs': py_method_defs,
        })
    

    最终通过查看native_functions.yaml的内容以及深入跟踪加载native_functions.yaml的代码发现,native_functions.yaml中的prelu最终会被写到以python_torch_functions.cpp为模板的文件中,也就是调用

        create_python_bindings(
            fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
    

    的时候被生成。整个生成的过程其实是很繁琐的,一层层跟踪后可以发现,最终生成的代码可以实现将一个名为at::<func_name>的函数暴露给Python。例如我们的prelu,暴露给Python的API最终会调用一个名为at::prelu()的函数来做真正的计算。那么这个at::<func_name>(例如at::prelu())的定义又在哪里呢?

    还是一样,故技重施!仍然使用Python脚本根据native_functions.yaml文件中的内容去以pytorch\aten\src\ATen\templates目录下的各种模板去生成对应的实际C++源文件。最终结果是得到at::<func_name>,在这个函数中,它调用了Dispatcher这个类寻找到目标函数的句柄。通常情况下能够使用的函数句柄都通过一个叫Library的类来管理。Python脚本以RegisterSchema.cpp为模板,生成了注册这些目标函数的注册代码,并通过一个名为TORCH_LIBRARY的宏调用Library类来注册管理。

    #define TORCH_LIBRARY(ns, m) \
      static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
      static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
        torch::Library::DEF, \
        &TORCH_LIBRARY_init_ ## ns, \
        #ns, c10::nullopt, __FILE__, __LINE__ \
      ); \
      void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
    
    class TorchLibraryInit final {
    private:
      using InitFn = void(Library&);
      Library lib_;
    public:
      TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
        : lib_(kind, ns, k, file, line) {
        fn(lib_);
      }
    };
    
    PyTorch组成示意图

    3. 总结

    PyTorch虽然在使用上是非常的Pythonic,但实际上Python只不过是为了方便使用裹在C++代码上的一层糖衣。用起来虽然好用,但是看起来实在是非常费劲,特别是如果静态的梳理代码,很多用于连接Python C/C++接口与实际逻辑代码之间的C++代码都是通过Python脚本生成的。至此,整个大的线索已经摸清了,剩下的就是去查看具体细节的实现。

    说实话,人脑执行Python代码之后再去理解C++代码实在是费劲,也费头发。因此我决定的让电脑去生成C++代码再接着看更具体的细节,比如究竟每一个算子是怎么注册到Library之中的。

    4. Bonus

    我真心怀疑我们生活在一个虚拟机里,为什么呢?因为到处可见运用于计算机里面的空间和时间局部性原理的实例。就在我写完这个博客的时候,意外的发现了一篇PyTorch工程师讲解PyTorch内部原理的博文,这对后续读代码应该会有很大帮助。等不及就戳它吧 http://blog.ezyang.com/2019/05/pytorch-internals/

    相关文章

      网友评论

          本文标题:看PyTorch源代码的心路历程

          本文链接:https://www.haomeiwen.com/subject/bqvgnktx.html