美文网首页
在PyTorch中由Python端调用C端底层函数

在PyTorch中由Python端调用C端底层函数

作者: CPinging | 来源:发表于2021-05-15 22:36 被阅读0次

今天看到一篇非常好的知乎blog,学到了如何在PyTorch下由python端的代码调用C代码。

一、目标

在DNN训练的过程中为了从Python端调用C的代码,方便接下去的科研。

二、前文

参考:https://zhuanlan.zhihu.com/p/358778742

在PyTorch的框架中我们能在下图的文件夹中找到load函数:

image.png

在框架中是这么描述这个函数的:Loads a PyTorch C++ extension just-in-time (JIT).
即使用即时编译将Python与C联系起来,并且是在python代码运行的过程中系统自动编译。

  • 这里要注意的地方是代码中要用pybind11进行呼应。下文细讲

三、内容

即时编译涉及如下文件:

  • 1 Python文件(主文件)

  • 2 C++头文件

  • 3 Cpp文件

  • 4 Cuda文件

在python中需要加入如下代码:

cuda_module = load(name="add2",
                           extra_include_paths=["/home/cping/PyTorch_scripts/Cuda_JIT_Test/NN-CUDA-Example/include"],
                           sources=["/home/cping/PyTorch_scripts/Cuda_JIT_Test/NN-CUDA-Example/pytorch/add2_ops.cpp", "/home/cping/PyTorch_scripts/Cuda_JIT_Test/NN-CUDA-Example/kernel/add2_kernel.cu"],
                           verbose=True)

其中add2就是个名字,可以换。但是extra_include_paths传入的是头文件的位置.h文件sources要给定路径,这里传入cpp以及cu文件。

verbose代表开启日志。
下面是例子:

Example:
        >>> from torch.utils.cpp_extension import load
        >>> module = load(
                name='extension',
                sources=['extension.cpp', 'extension_kernel.cu'],
                extra_cflags=['-O2'],
                verbose=True)

之后在python文件中调用:cuda_module.torch_launch_add2(c, a, b, n)

其中torch_launch_add2()为cpp中的函数

现在来看cpp代码:

#include <torch/extension.h>
#include "add2.h"

void torch_launch_add2(torch::Tensor &c,
                       const torch::Tensor &a,
                       const torch::Tensor &b,
                       int64_t n) {
    launch_add2((float *)c.data_ptr(),
                (const float *)a.data_ptr(),
                (const float *)b.data_ptr(),
                n);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("torch_launch_add2",
          &torch_launch_add2,
          "add2 kernel warpper");
}

这里我们要包含两个.h文件,其中add2.h是我们diy代码的头文件。

也就只有一句话;

void launch_add2(float *c,
                 const float *a,
                 const float *b,
                 int n);

回到cpp文件中,这里需要注意一个地方:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("torch_launch_add2",
          &torch_launch_add2,
          "add2 kernel warpper");
}

这个地方必须要加,可以理解为python找c的入口(python中的cuda_module.torch_launch_add2(c, a, b, n)),而这个函数就是cpp中的函数。

之后调用launch_add2函数。而这个函数在cu文件中实现:

__global__ void add2_kernel(float* c,
                            const float* a,
                            const float* b,
                            int n) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
            i < n; i += gridDim.x * blockDim.x) {
        c[i] = a[i] + b[i];
    }
}

void launch_add2(float* c,
                 const float* a,
                 const float* b,
                 int n) {
    dim3 grid((n + 1023) / 1024);
    dim3 block(1024);
    add2_kernel<<<grid, block>>>(c, a, b, n);
}

至此,逻辑到位,直接python3 xxx.py就可以运行,系统自动编译运算,如下图。

image.png

相关文章

  • 在PyTorch中由Python端调用C端底层函数

    今天看到一篇非常好的知乎blog,学到了如何在PyTorch下由python端的代码调用C代码。 一、目标 在DN...

  • TCP

    [toc] TCP TCP总流程 客户端、服务端:调用socket函数,创建套接字描述符 服务端调用bind函数,...

  • GO调用C函数

    GO调用C函数 在很多场景下,在Go的程序中需要调用c函数或者是用c编写的库(底层驱动,算法等,不想用Go语言再去...

  • GO微服务入门:什么是微服务

    微服务入门 理解RPC 像调用本地函数一样,调用远程函数。 GO Socket通信 server端 client端...

  • RPC框架

    RPC框架要做到的最基本的三件事: 1、服务端如何确定客户端要调用的函数(服务寻址); 在远程调用中,客户端和服务...

  • WebViewJavascriptBridge 调用过程(一)

    核心思想 1、JS端和OC端各生成一个全局的bridge来处理函数调用和回调函数调用。2、JS端的匿名函数对应OC...

  • Java NIO(一)-I/O模型: 阻塞、非阻塞、I/O复用、

    目的## 为后期学习 Netty框架打好理论基础,并且在分布式RPC 服务中对客户端与服务端之间服务的调用,底层数...

  • WebViewJavascriptBridge 调用过程(二)

    JS调用OC过程 以WKWebView为例1、OC端注册 2、JS端调用OC端注册的名称,并传参数设置回调函数。 ...

  • RPC概述

    1 基本介绍 1.1客户端介绍: 1.1.1 客户端代理 RPC 要求像调用本地函数一样来调用远程函数,所以需要对...

  • PY08-04:Python加载动态库

      Python加载动态库主要用于使用C/C++弥补Python的性能,这个主题解决了Python调用动态库中函数...

网友评论

      本文标题:在PyTorch中由Python端调用C端底层函数

      本文链接:https://www.haomeiwen.com/subject/ycwkjltx.html