美文网首页人工不智能PyTorch/JIT
从PyTorch到TorchScript: 打通深度学习模型的生

从PyTorch到TorchScript: 打通深度学习模型的生

作者: A君来了 | 来源:发表于2023-05-21 08:28 被阅读0次

    概述

    PyTorch是一款非常流行的深度学习框架,开发者和研究者常常选择它,因为它具有灵活性、易用性和良好的性能。然而,PyTorch的灵活易用性是建立在动态计算图的基础上的,相比采用静态图的TensorFlow,PyTorch在推理性能和部署方面存在明显的劣势。

    为了解决这个问题,TorchScript应运而生。它将PyTorch模型转换为静态类型的优化序列化格式,以实现高效的优化和跨平台部署(包括C++、Python、移动设备和云端)。

    构建 TorchScript

    TorchScript将PyTorch模型转换为静态图形式,因此构建TorchScript的核心是构建模型的静态计算图。
    PyTorch提供了两种方法来构建TorchScript:trace和script。

    • torch.jit.trace:该函数接收一个已训练好的模型和实际输入样例,通过运行模型的方式来生成静态图(static graph)。这种转换方式称为"追踪模式"(tracing mode)。

    • torch.jit.script:该函数将PyTorch代码编译成静态图。与追踪模式相反,它被称为"脚本模式"(scripting mode),因为它直接将PyTorch代码翻译成静态图,而不需要追踪执行流程。

    Tracing Mode

    model = torch.nn.Sequential(nn.Linear(3, 4))
    input = torch.randn(1, 3)
    traced_model = torch.jit.trace(model, input)
    

    追踪模式通过运行模型一次,并根据操作序列生成静态图。因此,它需要提供输入样例(input)。通过追踪机制,自动捕捉和生成模型的计算图。这是许多AI编译器采用的JIT模式。然而,追踪模式存在一个问题,即无法处理控制流,例如if、while等语句。

    class MyModel(nn.Module):
      def __init__(self):
        super().__init__()
    
      def forward(self, x):
        if x > 0:
          x += 1
        else:
          x -= 1
        return x
    
    model = MyModel()
    input = torch.randn(1)
    traced_model = torch.jit.trace(model, input)
    

    if语句为例,它是Python的语句,根据具体的x值,只能在then分支或else分支上执行。因此,追踪模式只能捕捉到一个分支上的操作。要想生成完整的控制流图,需要采用脚本模式。

    Scripting Mode

    与追踪模式不同,Scripting 模式直接将 Python 和 PyTorch 的语句翻译成 TorchScript 的静态图,因此不需要追踪模型的执行流程,并且能够生成完整的控制流图:

    script_model = torch.jit.script(model)
    print(script_model.graph)
    
    
    graph(%self : __torch__.___torch_mangle_3.MyModel,
          %x.1 : Tensor):
      ......
      %x : Tensor = prim::If(%6) # <ipython-input-3-6fda6c66b1df>:6:4
        block0():
          %x.7 : Tensor = aten::add_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:7:6
          -> (%x.7)
        block1():
          %x.13 : Tensor = aten::sub_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:9:6
          -> (%x.13)
      return (%x)
    

    然而,这种模式也有其局限性:对于每个语句,都需要提供相应的转换函数,将 Python/PyTorch 语句转换成 TorchScript 语句。目前,PyTorch仅支持部分 Python 内置函数和 PyTorch 语句的转换。

    Tracing + Script

    因此,对于具有控制流的模型,可以采用混合模式:将追踪模式无法处理的控制流图封装为子模块,使用脚本模式来转换这些子模块,然后通过追踪机制对整个模型进行追踪(通过脚本模式转换后的子模块不会再被追踪)。有关具体实现,请参考官方示例:https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting

    运行 TorchScript

    pytorch_jit_forward.png

    前面生成的计算图会封装到 TorchScript 模块的 forward() 方法中,在运行时被编译成 native code(JIT)。如上图所示,特化后的计算图经过图优化后被编译成 native code,最后通过栈机解释器执行。

    Specialization

    JIT(just-in-time)将静态图编译后的结果以 <signature: executable> 键值对的形式存储在缓存中。只有在缓存未命中(cache miss)时,也就是首次运行时,才会触发编译过程。

    Signature 表示唯一的静态计算图。在计算流图不变的情况下,静态图由输入参数(arguments)决定。不同的 dtype、shape 的参数将生成不同的静态计算图。

    Specialization 的目的是根据 torchscript 的输入(Input),为参数赋予 dtype、shape、设备类型(CPU、CUDA)等静态信息(ArgumentSpec),生成 signature,以便为缓存搜索做准备。

    # post specialization, inputs are now specialized types
    graph(%x : Float(*, *),
          %hx : Float(*, *),
          %cx : Float(*, *),
          %w_ih : Float(*, *),
          %w_hh : Float(*, *),
          %b_ih : Float(*),
          %b_hh : Float(*)):
      %7 : int = prim::Constant[value=4]()
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = aten::t(%w_ih)
    

    Optimization

    PyTorch JIT 使用一系列 passes(torch.jit.passes)对图进行优化,旨在从执行效率、内存占用等方面优化计算。其中包括对 dtype、shape 和常量进行前向推导的形状推导(Shape inference)和常数传播(Const propagation)等优化,以减少实际操作的数量。

    除了上述常见的优化,对于 GPU 来说,最核心的优化是算子融合(Operation fusion):将匹配的一组算子合并为一个算子。例如,将连续的一系列 element-wise 操作合并为一个操作,这样可以减少 CUDA kernels 的启动时间开销,并减少操作之间访问全局内存的次数。

    图优化是 AI 编译器的标配,用于优化计算图的执行效率和内存占用等方面。PyTorch 的图优化通过一系列 passes(torch.jit.passes)来实现,包括常数折叠(Constant folding)、死代码清除(Dead code elimination)和算子融合(Operation fusion)等。

    在图优化过程中,FuseGraph pass 将可以融合的算子封装为 FusionGroup 静态子图:

    graph(%x : Float(*, *),
          ...):
      %9 : Float(*, *) = aten::t(%w_ih)
      ...
      %77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
      %78 : Tensor[] = aten::broadcast_tensors(%77)
      %79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
      %hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
      %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
      return (%30);
    
    with prim::FusionGroup_0 = graph(%13 : Float(*, *),
      ...):
      %87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
      %82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
      %77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
      %72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
      %69 : int = prim::Constant[value=1]()
      %70 : Float(*, *) = aten::add(%77, %72, %69)
      %66 : Float(*, *) = aten::add(%78, %73, %69)
      ...
      %4 : Float(*, *) = aten::tanh(%cy)
      %hy : Float(*, *) = aten::mul(%outgate, %4)
      return (%hy, %cy)
    

    Codegen

    优化的最后是为图(symbolic graph)中的符号操作生成加速器所需的操作内核(op kernel)。PyTorch 已经为 CPU 和 Nvidia GPU 提供了一个名为 ATen 的 C++ 算子库,像图中的 aten::add 节点就会在运行时调用 built-in 算子。

    对于融合算子,PyTorch 提供了基于 LLVM 的 NNC 编译器,用于生成相应的目标代码。它将 FusionGroup 子图里的 node lowering 成 C++ functions,再基于 LLVM 将它们编译成一个大算子:

      RegisterNNCLoweringsFunction aten_matmul(
          {"aten::mm(Tensor self, Tensor mat2) -> (Tensor)",
           "aten::matmul(Tensor self, Tensor other) -> (Tensor)"},
          computeMatmul);
    
      Tensor computeMatmul(...) {
        ...
        return Tensor(
            ResultBuf.node(),
            ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
      }
    
      void nnc_aten_matmul(...) {
        ...
        try {
          at::matmul_out(r, self, other);
        } catch (...) {}
      }
    

    Interpreter

    TorchScript 提供一个栈机解释器在C++上高效地运行计算图:

    // Create a vector of inputs.
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 3, 224, 224}));
    
    // Execute the model and turn its output into a tensor.
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
    

    END

    相关文章

      网友评论

        本文标题:从PyTorch到TorchScript: 打通深度学习模型的生

        本文链接:https://www.haomeiwen.com/subject/mzcssdtx.html