Pytorch和CUDA联合编程的基本步骤

参考链接-w3school

参考链接-CSDN

本文代码-github

背景

目前PyTorch已经提供了丰富的接口可以直接调用,但是仍存在一些高度自定义的操作无法使用PyToch或者Python高效的完成,因此PyTorch还提供了使用C++和CUDA编程的扩展接口。C++扩展主要有两种形式,一种是使用setuptools提前构建,也可以通过torch.utils.cpp_extension.load()在运行时构建。下面仅介绍第一种方法,第二种方法之后再学习。

基本步骤

Pytorch,CUDA,C++联合编程的一般步骤如下:

  1. 首先需要定义一个C++文件,该文件声明了CUDA文件中定义的函数,还需要进行一些检查,并最终将其调用转发给.cu文件。此外,该文件还需要声明将在Python中调用的函数,并使用pybind11绑定到python。OpenPCDet将上述步骤划分为以下几个步骤:
    1. 首先定义一个头文件,该头文件.h中包含了.cu文件中定义的函数和.cpp文件中定义的函数.

    2. 然后定义.cpp文件,其中函数的作用是负责进行一些检查和调用.cu文件中定义的函数

    3. .cu文件是负责执行具体的CUDA编程的操作

    4. api文件是将.cpp文件中定义的函数和PYBIND11绑定,以便Python调用

  2. setup.py文件中声明将要编译的模块名称,源文件路径等。
  3. 使用import导入声明的模块,使用Python实现其前向和反向传播的计算。

举例说明

.
├── ball_query_src
│   ├── api.cpp
│   ├── ball_query.cpp
│   ├── ball_query_cuda.cu
│   ├── ball_query_cuda.h
│   └── cuda_utils.h
├── setup.py
└── test_ball_query.py

项目的目录如上图所示,其中api.cpp文件是将ball_query.cpp声明的函数使用PYBIND11与python进行绑定。其中api.cpp的内容如下:

#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "ball_query_cuda.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // 第一个参数表示的是在python中调用的名称,第二个参数是对应的cpp函数,第三个参数对应的是这个函数的说明
    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
}

这里,作者为了使得代码的结构更加清晰,其在ball_query.h文件中分别声明了两个函数,一个是在C++中被调用的函数,另一个是在CUDA中实现的函数。其具体内容如下:

#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

// 与pybind11绑定的函数,其主要作用是调用下面的cuda函数
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

// CUDA文件中的函数
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 
    const float *xyz, const float *new_xyz, int *idx);

#endif

.h文件中声明了上述两个文件之后,再分别再ball_query.cppball_query_cuda.cu文件中完成这两个函数的具体实现。

#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>

#include "ball_query_cuda.h"

extern THCState *state;

// 定义检查数据类型的宏
#define CHECK_CUDA(x) do { \
      if (!x.type().is_cuda()) { \
              fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
              exit(-1); \
            } \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
      if (!x.is_contiguous()) { \
              fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
              exit(-1); \
            } \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

// 完成输入数据类型的检查,同时调用cu文件中定义的函数
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);
    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();
    
    ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx);
    return 1;
}

ball_query_cuda.cu实现

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "ball_query_cuda.h"
#include "cuda_utils.h"


__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 
    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
    // new_xyz: (B, M, 3)
    // xyz: (B, N, 3)
    // output:
    //      idx: (B, M, nsample)
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2){
            if (cnt == 0){
                for (int l = 0; l < nsample; ++l) {
                    idx[l] = k;
                }
            }
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample) break;
        }
    }
}


void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
    const float *new_xyz, const float *xyz, int *idx) {
    // new_xyz: (B, M, 3)
    // xyz: (B, N, 3)
    // output:
    //      idx: (B, M, nsample)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ball_query_kernel_fast<<<blocks, threads>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

至此,ball_query的核心功能已经完成,然后我们需要使用setup.py文件来编译上述文件。setup.py文件的具体实现如下:

import os
import subprocess

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def make_cuda_ext(name, module, sources):
    cuda_ext = CUDAExtension(
        name='%s.%s' % (module, name),
        sources=[os.path.join(module.split('.')[-1], src) for src in sources]
    )
    print([os.path.join(*module.split('.'), src) for src in sources])
    return cuda_ext

if __name__ == '__main__':
    setup(
        name='ballquery',
        packages=find_packages(),
        ext_modules=[
            CUDAExtension('ball_query_cuda',[
                'ball_query_src/api.cpp',
                'ball_query_src/ball_query.cpp',
                'ball_query_src/ball_query_cuda.cu',  
            ])
        ],
        cmdclass={
            'build_ext': BuildExtension
        }
    )

至此,我们已经生成了上述代码的链接库,但是如果需要将其嵌入到神经网络中,还需要定义其前向传播和反向传播方法。这里我们在test_ball_query.py文件中完成其前向传播和反向传播。

import torch
import torch.nn as nn
from torch.autograd import Function, Variable
import math

import ball_query_cuda

# 定义该方法的前向传播和反向传播方法
class BallQuery(Function):
    
    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
        
        ball_query_cuda.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply

xyz = torch.randn(2, 128, 3).cuda()
new_xyz = xyz

result = ball_query(0.8, 3, xyz, new_xyz)

print(result.shape)
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,921评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,635评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,393评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,836评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,833评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,685评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,043评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,694评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,671评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,670评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,779评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,424评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,027评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,984评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,214评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,108评论 2 351
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,517评论 2 343

推荐阅读更多精彩内容