美文网首页
C++ windows调用ubuntu训练的PyTorch模型(

C++ windows调用ubuntu训练的PyTorch模型(

作者: 有事没事扯扯淡 | 来源:发表于2020-09-04 14:50 被阅读0次

    之前写到用ubuntu上训练pytorch的网络,再window c++调用的文章。实际调用过程中,并不是简单load就可以,需要将PyTorch模型转换为Torch Script,在这里记录一下,供有需要的盆友参考。

    PyTorch模型从Python到C++的转换由Torch Script实现。Torch Script是PyTorch模型的一种表示,可由Torch Script编译器理解,编译和序列化。

    将PyTorch模型转换为Torch Script有两种方法。

    • 第一种方法是Tracing。该方法通过将样本输入到模型中一次来对该过程进行评估从而捕获模型结构.并记录该样本在模型中的flow。该方法适用于模型中很少使用控制flow的模型。(就是没有if...else... for ...这种的。。)
    • 第二个方法就是向模型添加显式注释(Annotation),通知Torch Script编译器它可以直接解析和编译模型代码,受Torch Script语言强加的约束。
    利用Tracing将模型转换为Torch Script

    要通过tracing来将PyTorch模型转换为Torch脚本,必须将模型的实例以及样本输入传递给torch.jit.trace函数。这将生成一个 torch.jit.ScriptModule对象,并在模块的forward方法中嵌入模型评估的跟踪(我用的就是这个):

    import torch
    import torchvision
    
    # 获取模型实例
    model = torchvision.models.resnet18()
    
    # 生成一个样本供网络前向传播 forward()
    example = torch.rand(1, 3, 224, 224)
    
    # 使用 torch.jit.trace 生成 torch.jit.ScriptModule 来跟踪
    traced_script_module = torch.jit.trace(model, example)
    
    traced_script_module.save('model.pt')
    
    通过Annotation将Model转换为Torch Script

    在某些情况下,例如,如果模型使用特定形式的控制流,如果想要直接在Torch Script中编写模型并相应地标注(annotate)模型。例如,假设有以下普通的 Pytorch模型:

    import torch
    
    class MyModule(torch.nn.Module):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))
    
        def forward(self, input):
            if input.sum() > 0:
              output = self.weight.mv(input)
            else:
              output = self.weight + input
            return output
    

    由于此模块的forward方法使用依赖于输入的控制流,因此它不适合利用Tracing的方法生成Torch Script。为此,可以通过继承torch.jit.ScriptModule并将@ torch.jit.script_method标注添加到模型的forward中的方法,来将model转换为ScriptModule:

    import torch
    
    class MyModule(torch.jit.ScriptModule):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))
    
        @torch.jit.script_method
        def forward(self, input):
            if input.sum() > 0:
              output = self.weight.mv(input)
            else:
              output = self.weight + input
            return output
    
    my_script_module = MyModule()
    

    现在,创建一个新的MyModule对象会直接生成一个可序列化的ScriptModule实例了。

    后面就可以参考之前的文章(windows+VS2019+PyTorchLib配置使用攻略
    )进行调用了~~~( •̀ ω •́ )y

    [参考链接]
    https://pytorch.apachecn.org/docs/1.0/cpp_export.html

    相关文章

      网友评论

          本文标题:C++ windows调用ubuntu训练的PyTorch模型(

          本文链接:https://www.haomeiwen.com/subject/tfvzsktx.html