Sgemm [128,128,8]

主要参考 论文 Huang, 2018 (arxiv.org)

性能可达到 cublas的 96%
目前只贴下源码,注释还是蛮多的。
之前搞了个分支,速度下降了10%,有些过分了。

#include <algorithm>
#include <cublas_v2.h>
#include <cuda_device_runtime_api.h>
#include <device_launch_parameters.h>
#include <iomanip>
#include <iostream>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <vector>
/**
 * @brief
 *  gemm: C= alpha*A*B+beta*C
 * 每个 thread block计算 C 的tile是 [128*128]
 * 子迭代计算 Atile[128,8] Btile[8,128]
 * 所有矩阵按照行主序计算。
 * 是有数据预取,向量化读取等
 * block的threads是 256个
 * 一个 warp要计算的tile是 [32,64]
 * 一个 thread计算多 tile是 [8,8]
 * 因为要读取合并,对C采用循环分发的方式
 */

/**********************/
/* cuBLAS ERROR CHECK */
/**********************/
#ifndef cublasSafeCall
#define cublasSafeCall(err) __cublasSafeCall(err, __FILE__, __LINE__)
#endif

inline void __cublasSafeCall(cublasStatus_t err, const char *file,
                             const int line)
{
    if (CUBLAS_STATUS_SUCCESS != err)
    {
        fprintf(stderr,
                "CUBLAS error in file '%s', line %d\n \nerror %d \nterminating!\n",
                __FILE__, __LINE__, err);
        // getch();
        cudaDeviceReset();
        assert(0);
    }
}
template <typename T>
struct Type4;

template <>
struct Type4<float>
{
    using type = float4;
};

template <typename T>
using Type4t = typename Type4<T>::type;

#define A(i, j) A[(i)*lda + (j)]
#define B(i, j) B[(i)*ldb + (j)]
#define C(i, j) C[(i)*ldc + (j)]

#define ptrA(i, j) ptrA[(i)*lda + (j)]
#define ptrB(i, j) ptrB[(i)*ldb + (j)]
#define MS 128
#define NS 128
#define KS 8

template <typename T>
__device__ __forceinline__ void
vscal_fma(Type4t<T> &dst_vec, const Type4t<T> &src_vec, const T &scale)
{
    dst_vec.x += src_vec.x * scale;
    dst_vec.y += src_vec.y * scale;
    dst_vec.z += src_vec.z * scale;
    dst_vec.w += src_vec.w * scale;
}

template <typename T>
__device__ __forceinline__ void simd_axpby(Type4t<T> &dst_vec, T alpha,
                                           const Type4t<T> &srca_vec, T beta,
                                           const Type4t<T> &srcb_vec)
{
    dst_vec.x = alpha * srca_vec.x + beta * srcb_vec.x;
    dst_vec.y = alpha * srca_vec.y + beta * srcb_vec.y;
    dst_vec.z = alpha * srca_vec.z + beta * srcb_vec.z;
    dst_vec.w = alpha * srca_vec.w + beta * srcb_vec.w;
}

template <typename T>
__device__ __forceinline__ void vload(Type4t<T> &dst_vec, const T *addr)
{
    dst_vec = *((Type4t<T> *)(addr));
}

template <typename T>
__device__ __forceinline__ void vstore(T *addr, const Type4t<T> &src_vec)
{
    *((Type4t<T> *)(addr)) = src_vec;
}

__device__ __forceinline__ void print(const float4 &vec)
{
    printf("%f, %f, %f, %f\n", vec.x, vec.y, vec.z, vec.w);
}
/**
 * template <typename AccessType>
struct global_load<AccessType,
                   16
                  > {
  CUTLASS_DEVICE
  global_load(AccessType &D, void const *ptr, bool pred_guard) {
  uint4 &data = reinterpret_cast<uint4 &>(D);
    asm volatile(
        "{\n"
        "  .reg .pred p;\n"
        "  setp.ne.b32 p, %5, 0;\n"
        "  mov.b32 %0, %6;\n"
        "  mov.b32 %1, %7;\n"
        "  mov.b32 %2, %8;\n"
        "  mov.b32 %3, %9;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
        "  @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n"
#else
        "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
#endif
        "}\n"
        : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
        : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z),
"r"(data.w));
  }
};

**/
__global__ __launch_bounds__(256)
    // assert M%128 ==0 && N%128 ==0 && K%8 ==0
    void gemm_kernel(int M, int N, int K, float alpha, const float *A,
                     const float *B, float beta, float *C)
{

    const int lda = K;
    const int ldb = N;
    const int ldc = N;
    const int tx = threadIdx.x;
    const int bx = blockIdx.x, by = blockIdx.y;
    // bx 是对应col,by是对应 row 的 行主序

    const int warp_id = tx >> 5;
    const int lane_id = tx & 31;
    // 共 8个warp, 分成 4行 2列
    // 因为 一个warp处理的元素是[32,64] 128/32 = 4,128/64=2 ,所以需要
    // 4行warp,2列 warp。
    const int warp_col = warp_id & 1;
    const int warp_row = warp_id >> 1;
    // 一个warp负责 4* 8,8* 8 的 Ctile //但需要两次读取
    //  lane_id 分成 [4,8]的方块布局
    const int col_w = lane_id & 7;
    const int row_w = lane_id >> 3;
    //每个 tx 首先由 8行4列的C tile,再循环分发一次,共 8行8列
    // 因为 C是行主序,col索引是连续的,连续的col索引应该给相邻的tx(float4 是
    // 16B,不满足32B的最小L1,L2内存交易单位)
    // C的列索引需要循环分发,行索引不需要循环分发[ col0-col3 ,col64-col67:
    // row0-row7], 一个warp是 [32,64] 所以 rowc 中 warp_row<< 5 ,warp_col<<6
    //但列需要循环分发,要分成两次读取,因此这里是  (warp_col << 5) + (col_w << 2)
    //;
    const int row_c = (warp_row << 5) + (row_w << 3); // row0-row7
    const int col_c = (warp_col << 5) + (col_w << 2); // col0-col3 ;col64-col67
    // col_w<< 2 是因为 一个tx一次读写or 4列。需要循环分发,col0-col3 ;col64-col67

    // rowa rowb cola colb 用于从global memory读取到share memory
    // Atile [128,8] 每个线程读写到 share memory是 4个值,
    // 256个 tx 分成 [128,2]布局
    const int col_a = (tx & 1) << 2;
    const int row_a = tx >> 1;
    // Btile [8,128] 每个线程读写到 share memory是 4个值
    // 256个 tx 分成 [8,32] 布局
    const int col_b = (tx & 31) << 2; // 4个值所以是 <<2
    const int row_b = tx >> 5;

    // 该block处理的 A,B,C相对地址
    A += (by << 7) * lda;
    B += (bx << 7);
    C += (by << 7) * ldc + (bx << 7);

    __shared__ float
        smema[2][8][128]; // Atile[128][8]但需要转置以方便计算的时候读取A的4行1列
                          // //8行1列
    __shared__ float smemb[2][8]
                          [128]; // Btile[8][128] 计算的时候读取 B的 1行4列  //
                                 // 1行 8列 // 2是双缓存策略以减少 一次sync
    // auto *ptr_smema = &smema[0];
    // float *[8][128] type(ptr_smema)
    // auto *ptr_smemb = &smemb[0];
    float4 Av1[2], Av2[2], Bv1[2], Bv2[2], Cv[16], Cres[16];
    // Av[2]正好8个值,Av1[0],Av1[1]用于预取策略的交换。 Cv[16]共 16*4= 8*8个值。
    float4 pref_Av, pref_Bv; //从global memory中预取的值
                             // 不直接使用 A,B 大概是因为
                             // 预取的时候用过Alia防止最后一次循环预取越界的情况。
    memset(Cres, 0, sizeof(Cres));
    // 循环之前先来次预取, 分别是 global memory的预取,和share memory的预取
    vload(pref_Av, &(A(row_a, col_a)));
    vload(pref_Bv, &(B(row_b, col_b)));
    int buffer_switch = 0; // two sharememoy buffer switch

    vstore(&(smemb[0][row_b][col_b]), pref_Bv);
    smema[0][col_a][row_a] = pref_Av.x;
    smema[0][col_a + 1][row_a] = pref_Av.y;
    smema[0][col_a + 2][row_a] = pref_Av.z;
    smema[0][col_a + 3][row_a] = pref_Av.w;
    // 写入到sharememoy后 进行sync

    __syncthreads();

    //读取 Atile [row_c,k] k(0-7), 因为转置,Atil的row是连续的
    //读取 Btile [k,col_c] col是循环分发的,分配到的 colc 0-3,64-67
    vload(Av1[0], &smema[0][0][row_c]);
    vload(Av2[0], &smema[0][0][row_c + 4]);
    vload(Bv1[0], &smemb[0][0][col_c]);
    vload(Bv2[0], &smemb[0][0][col_c + 64]);
    for (int global_k = KS; global_k < K; global_k += KS)
    {

        // global_k起始值是 KS, 因为循环之前已经预取了一次,为避免global
        // access越界,这里只到global_k<K,也可以global_k<=K,然后 取余 内存值
        // global_k%K 把其限制在 [0,K)里
        //但这样会计算 k_iteration-1次,还有一次剩余的需要计算
        // main loop 在 K上循环,每次迭代 KS

        //读取下一次计算的global memory
        //重新定位 ptrA,ptrB等

        // 这么大块的语句会不会分支跳转??
        A += KS;
        // global_K为上移动
        B += (KS * ldb);
        vload(pref_Av, &(A(row_a, col_a)));
        // global memory 的读取实际上一个tx 一次,[128,8] [8,128]
        // (整个Block的tx映射到 [128,8] [8,128])
        vload(pref_Bv, &(B(row_b, col_b)));
        // 最后一次 迭代就会越界, 这里条件判断还是用  @p
        // 比较靠谱,不如直接ptx汇编,要不类似 cutlass那样inline ptx gloabal load

        int reg_switch = 0;
#pragma unroll
        for (int inner_k_count = 0; inner_k_count < KS; inner_k_count++)
        {
            int next_inner_k_count = (inner_k_count + 1) & 7; //(1...7;1)取余
            // prefecth data from smem  to register for nex iter compute
            int next_reg = reg_switch ^ 1;
            vload(Av1[next_reg], &smema[buffer_switch][next_inner_k_count][row_c]);
            vload(Av2[next_reg],
                  &smema[buffer_switch][next_inner_k_count][row_c + 4]);
            vload(Bv1[next_reg], &smemb[buffer_switch][next_inner_k_count][col_c]);
            vload(Bv2[next_reg],
                  &smemb[buffer_switch][next_inner_k_count][col_c + 64]);
            // next_inner_k_count& 1用以切换预取的register
            //行主序,一行是连续的则 索引col是连续的
            vscal_fma(Cres[0], Bv1[reg_switch], Av1[reg_switch].x);

            vscal_fma(Cres[1], Bv1[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc0-colc3]
            vscal_fma(Cres[2], Bv1[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc0-colc3]
            vscal_fma(Cres[3], Bv1[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc0-colc3]
            vscal_fma(Cres[4], Bv1[reg_switch], Av2[reg_switch].x);
            // [row_c+4,col_c0],[row_c,colc_1],[row_c,colc_2],[row_c,colc_3]
            vscal_fma(Cres[5], Bv1[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc0-colc3]
            vscal_fma(Cres[6], Bv1[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc0-colc3]
            vscal_fma(Cres[7], Bv1[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc0-colc3]
            vscal_fma(Cres[8], Bv2[reg_switch], Av1[reg_switch].x);

            // [row_c,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[9], Bv2[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc64-colc67]
            vscal_fma(Cres[10], Bv2[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc64-colc67]
            vscal_fma(Cres[11], Bv2[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc64-colc67]
            vscal_fma(
                Cres[12], Bv2[reg_switch],
                Av2[reg_switch]
                    .x); // [row_c+4,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[13], Bv2[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc64-colc67]
            vscal_fma(Cres[14], Bv2[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc64-colc67]
            vscal_fma(Cres[15], Bv2[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc64--colc67]
            reg_switch ^= 1;
        }

        buffer_switch ^= 1;
        // two sharememoy buffer switch
        // store memoy in buffer
        vstore(&(smemb[buffer_switch][row_b][col_b]), pref_Bv);
        smema[buffer_switch][col_a][row_a] = pref_Av.x;
        smema[buffer_switch][col_a + 1][row_a] = pref_Av.y;
        smema[buffer_switch][col_a + 2][row_a] = pref_Av.z;
        smema[buffer_switch][col_a + 3][row_a] = pref_Av.w;
        __syncthreads();
        //从 sharememoy 读值
        vload(Av1[0], &smema[buffer_switch][0][row_c]);
        vload(Av2[0], &smema[buffer_switch][0][row_c + 4]);
        vload(Bv1[0], &smemb[buffer_switch][0][col_c]);
        vload(Bv2[0], &smemb[buffer_switch][0][col_c + 64]);

        // 为下一次子循环预取值
    }

// 这个版本去掉了分支,然后手动加一次计算子循环,性能提升了10%
int reg_switch = 0;
#pragma unroll
        for (int inner_k_count = 0; inner_k_count < KS; inner_k_count++)
        {
            int next_inner_k_count = (inner_k_count + 1) & 7; //(1...7;1)取余
            // prefecth data from smem  to register for nex iter compute
            int next_reg = reg_switch ^ 1;
            vload(Av1[next_reg], &smema[buffer_switch][next_inner_k_count][row_c]);
            vload(Av2[next_reg],
                  &smema[buffer_switch][next_inner_k_count][row_c + 4]);
            vload(Bv1[next_reg], &smemb[buffer_switch][next_inner_k_count][col_c]);
            vload(Bv2[next_reg],
                  &smemb[buffer_switch][next_inner_k_count][col_c + 64]);
            // next_inner_k_count& 1用以切换预取的register
            //行主序,一行是连续的则 索引col是连续的
            vscal_fma(Cres[0], Bv1[reg_switch], Av1[reg_switch].x);

            vscal_fma(Cres[1], Bv1[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc0-colc3]
            vscal_fma(Cres[2], Bv1[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc0-colc3]
            vscal_fma(Cres[3], Bv1[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc0-colc3]
            vscal_fma(Cres[4], Bv1[reg_switch], Av2[reg_switch].x);
            // [row_c+4,col_c0],[row_c,colc_1],[row_c,colc_2],[row_c,colc_3]
            vscal_fma(Cres[5], Bv1[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc0-colc3]
            vscal_fma(Cres[6], Bv1[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc0-colc3]
            vscal_fma(Cres[7], Bv1[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc0-colc3]
            vscal_fma(Cres[8], Bv2[reg_switch], Av1[reg_switch].x);

            // [row_c,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[9], Bv2[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc64-colc67]
            vscal_fma(Cres[10], Bv2[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc64-colc67]
            vscal_fma(Cres[11], Bv2[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc64-colc67]
            vscal_fma(
                Cres[12], Bv2[reg_switch],
                Av2[reg_switch]
                    .x); // [row_c+4,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[13], Bv2[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc64-colc67]
            vscal_fma(Cres[14], Bv2[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc64-colc67]
            vscal_fma(Cres[15], Bv2[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc64--colc67]
            reg_switch ^= 1;
        }

    // 上面在global_k上的主循环实际上少迭代一次,因为预取相关的问题,对全局内存不能越界
    // load Ctile and accumulate the Cres
    vload(Cv[0], &C(row_c, col_c));
    vload(Cv[1], &C(row_c + 1, col_c));
    vload(Cv[2], &C(row_c + 2, col_c));
    vload(Cv[3], &C(row_c + 3, col_c));
    vload(Cv[4], &C(row_c + 4, col_c));
    vload(Cv[5], &C(row_c + 5, col_c));
    vload(Cv[6], &C(row_c + 6, col_c));
    vload(Cv[7], &C(row_c + 7, col_c));
    vload(Cv[8], &C(row_c, col_c + 64));
    vload(Cv[9], &C(row_c + 1, col_c + 64));
    vload(Cv[10], &C(row_c + 2, col_c + 64));
    vload(Cv[11], &C(row_c + 3, col_c + 64));
    vload(Cv[12], &C(row_c + 4, col_c + 64));
    vload(Cv[13], &C(row_c + 5, col_c + 64));
    vload(Cv[14], &C(row_c + 6, col_c + 64));
    vload(Cv[15], &C(row_c + 7, col_c + 64));

    simd_axpby(Cres[0], alpha, Cres[0], beta, Cv[0]);
    simd_axpby(Cres[1], alpha, Cres[1], beta, Cv[1]);
    simd_axpby(Cres[2], alpha, Cres[2], beta, Cv[2]);
    simd_axpby(Cres[3], alpha, Cres[3], beta, Cv[3]);

    simd_axpby(Cres[4], alpha, Cres[4], beta, Cv[4]);
    simd_axpby(Cres[5], alpha, Cres[5], beta, Cv[5]);
    simd_axpby(Cres[6], alpha, Cres[6], beta, Cv[6]);
    simd_axpby(Cres[7], alpha, Cres[7], beta, Cv[7]);

    simd_axpby(Cres[8], alpha, Cres[8], beta, Cv[8]);
    simd_axpby(Cres[9], alpha, Cres[9], beta, Cv[9]);
    simd_axpby(Cres[10], alpha, Cres[10], beta, Cv[10]);
    simd_axpby(Cres[11], alpha, Cres[11], beta, Cv[11]);

    simd_axpby(Cres[12], alpha, Cres[12], beta, Cv[12]);
    simd_axpby(Cres[13], alpha, Cres[13], beta, Cv[13]);
    simd_axpby(Cres[14], alpha, Cres[14], beta, Cv[14]);
    simd_axpby(Cres[15], alpha, Cres[15], beta, Cv[15]);

    vstore(&C(row_c, col_c), Cres[0]);
    vstore(&C(row_c + 1, col_c), Cres[1]);
    vstore(&C(row_c + 2, col_c), Cres[2]);
    vstore(&C(row_c + 3, col_c), Cres[3]);
    vstore(&C(row_c + 4, col_c), Cres[4]);
    vstore(&C(row_c + 5, col_c), Cres[5]);
    vstore(&C(row_c + 6, col_c), Cres[6]);
    vstore(&C(row_c + 7, col_c), Cres[7]);

    vstore(&C(row_c, col_c + 64), Cres[8]);
    vstore(&C(row_c + 1, col_c + 64), Cres[9]);
    vstore(&C(row_c + 2, col_c + 64), Cres[10]);
    vstore(&C(row_c + 3, col_c + 64), Cres[11]);
    vstore(&C(row_c + 4, col_c + 64), Cres[12]);
    vstore(&C(row_c + 5, col_c + 64), Cres[13]);
    vstore(&C(row_c + 6, col_c + 64), Cres[14]);
    vstore(&C(row_c + 7, col_c + 64), Cres[15]);
}
// 3750.55 3591.11 3529.28 3460.77
template <typename T>
bool verify_res(size_t m, size_t n, const thrust::device_vector<T> &ref_data,
                const thrust::device_vector<T> &res_data,
                T abs_error = T(1e-2))
{
    thrust::host_vector<T> href_data = ref_data;
    thrust::host_vector<T> hres_data = res_data;

    T max_error = std::numeric_limits<T>::lowest();
    int num_errors = 0;
    for (size_t i = 0; i < m; i++)
    {
        for (size_t j = 0; j < n; j++)
        {
            auto tmp_error = std::abs(hres_data[i * n + j] - href_data[i * n + j]);
            // std::cout<<tmp_error<<"\n";
            if (tmp_error > abs_error)
            {
                num_errors++;
                max_error = max_error < tmp_error ? tmp_error : max_error;
            }
        }
    }
    std::cout << "num_error: " << num_errors << " max error= " << max_error
              << " \n";
    return num_errors == 0;
}

template <typename T>
void host_gemm(int M, int N, int K, std::vector<T> &A, std::vector<T> &B,
               std::vector<T> &C, T alpha, T beta)
{
    for (int m = 0; m < M; m++)
    {
        for (int n = 0; n < N; n++)
        {
            T accum = 0;
            for (int k = 0; k < K; k++)
            {
                accum += A[m * K + k] * B[k * N + n];
            }
            C[m * N + n] = alpha * accum + beta * C[m * N + n];
        }
    }
}

// print an M-by-N array
template <typename T>
void print(size_t m, size_t n, thrust::device_vector<T> &d_data)
{
    thrust::host_vector<T> h_data = d_data;

    for (size_t i = 0; i < m; i++)
    {
        for (size_t j = 0; j < n; j++)
            std::cout << std::setw(1) << h_data[i * n + j] << " ";
        std::cout << "\n";
    }
}

int main(int argc, char **argv)
{
    const int M = 6144;
    const int N = 6144;
    const int K = 6144;
    constexpr int Ms = 128;
    constexpr int Ns = 128;
    constexpr int Ks = 8;
    using Element = float;
    std::vector<Element> hA(M * K);
    std::vector<Element> hB(K * N);
    std::vector<Element> hC(M * N);
    std::random_device rd;  // 将用于获得随机数引擎的种子
    std::mt19937 gen(rd()); // 以 rd() 播种的标准 mersenne_twister_engine
    std::uniform_real_distribution<Element> dis(1, 10);
    std::generate(hA.begin(), hA.end(), [&rd, &gen, &dis]()
                  { return dis(gen); });
    std::generate(hB.begin(), hB.end(), [&rd, &gen, &dis]()
                  { return dis(gen); });
    thrust::device_vector<Element> dA = hA;
    thrust::device_vector<Element> dB = hB;
    thrust::device_vector<Element> dC(M * N);
    thrust::device_vector<Element> drefC(M * N);
    Element *dA_ptr = thrust::raw_pointer_cast(dA.data());
    Element *dB_ptr = thrust::raw_pointer_cast(dB.data());
    Element *dC_ptr = thrust::raw_pointer_cast(dC.data());
    Element *dCref_ptr = thrust::raw_pointer_cast(drefC.data());

    cublasHandle_t handle;
    cublasSafeCall(cublasCreate(&handle));
    float alpha = 1.;
    float beta = 0.;
    cublasSafeCall(cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
                                  &alpha, dB_ptr, N, dA_ptr, K, &beta, dCref_ptr,
                                  N));

    const dim3 block(256);
    const dim3 grid((M + 127) / 128, (N + 127) / 128);
    gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);

    gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);
    std::cout << cudaGetErrorString(cudaGetLastError()) << "\n";
    verify_res(M, N, dC, drefC);

    // host_gemm(M, N, K, hA, hB, hC, alpha, beta);
    // verify_res(M, N, thrust::device_vector<Element>(hC), drefC);
    // std::cout << "hc verify dc \n";
    // verify_res(M, N, thrust::device_vector<Element>(hC), dC);
    // print(M, N, dC);

    // print(M, N, drefC);
}

//
//
// nvcc -arch=sm_75 -O3   ./gemm.cu -o gemmtest -lcublas
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,732评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,496评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,264评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,807评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,806评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,675评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,029评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,683评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 41,704评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,666评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,773评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,413评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,016评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,978评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,204评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,083评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,503评论 2 343

推荐阅读更多精彩内容

  • 我是黑夜里大雨纷飞的人啊 1 “又到一年六月,有人笑有人哭,有人欢乐有人忧愁,有人惊喜有人失落,有的觉得收获满满有...
    陌忘宇阅读 8,520评论 28 53
  • 首先介绍下自己的背景: 我11年左右入市到现在,也差不多有4年时间,看过一些关于股票投资的书籍,对于巴菲特等股神的...
    瞎投资阅读 5,653评论 3 8
  • ![Flask](...
    极客学院Wiki阅读 7,229评论 0 3
  • 不知不觉易趣客已经在路上走了快一年了,感觉也该让更多朋友认识知道易趣客,所以就谢了这篇简介,已做创业记事。 易趣客...
    Physher阅读 3,407评论 1 2