Based mainly on the paper Huang, 2018 (arxiv.org).
Performance reaches about 96% of cuBLAS.
For now I am just posting the source code; the comments in it are fairly extensive.
An earlier version kept a branch in the main loop and ran about 10% slower, which was too much.
#include <algorithm>
#include <cassert> // for assert() in the error-check helpers
#include <cublas_v2.h>
#include <cuda_device_runtime_api.h>
#include <device_launch_parameters.h>
#include <iomanip>
#include <iostream>
#include <limits> // for std::numeric_limits in verify_res
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <vector>
/**
 * @brief
 * gemm: C = alpha*A*B + beta*C
 * Each thread block computes a [128,128] tile of C.
 * Each sub-iteration consumes an Atile[128,8] and a Btile[8,128].
 * All matrices are row-major.
 * Uses data prefetching and vectorized (float4) loads.
 * Each block has 256 threads.
 * Each warp computes a [32,64] tile.
 * Each thread computes an [8,8] micro-tile.
 * For coalesced access, C's columns are distributed cyclically across
 * threads.
 */
/**********************/
/* cuBLAS ERROR CHECK */
/**********************/
#ifndef cublasSafeCall
#define cublasSafeCall(err) __cublasSafeCall(err, __FILE__, __LINE__)
#endif
inline void __cublasSafeCall(cublasStatus_t err, const char *file,
                             const int line)
{
  if (CUBLAS_STATUS_SUCCESS != err)
  {
    fprintf(stderr,
            "CUBLAS error in file '%s', line %d\nerror %d\nterminating!\n",
            file, line, err); // report the call site, not this helper
    cudaDeviceReset();
    assert(0);
  }
}
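/********************/
/* CUDA ERROR CHECK */
/********************/
// A matching guard for CUDA runtime calls, added for symmetry with
// cublasSafeCall (the original only checks cuBLAS status); a minimal sketch:
#ifndef cudaSafeCall
#define cudaSafeCall(err) __cudaSafeCall(err, __FILE__, __LINE__)
#endif
inline void __cudaSafeCall(cudaError_t err, const char *file, const int line)
{
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA error '%s' in file '%s', line %d\nterminating!\n",
            cudaGetErrorString(err), file, line);
    cudaDeviceReset();
    assert(0);
  }
}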
template <typename T>
struct Type4;
template <>
struct Type4<float>
{
using type = float4;
};
template <typename T>
using Type4t = typename Type4<T>::type;
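// The same trait pattern extends to other element types; e.g. a double
// specialization would map to CUDA's double4 (shown only to illustrate the
// pattern; the kernel below is float-only):
template <>
struct Type4<double>
{
  using type = double4;
};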
#define A(i, j) A[(i)*lda + (j)]
#define B(i, j) B[(i)*ldb + (j)]
#define C(i, j) C[(i)*ldc + (j)]
#define ptrA(i, j) ptrA[(i)*lda + (j)]
#define ptrB(i, j) ptrB[(i)*ldb + (j)]
#define MS 128
#define NS 128
#define KS 8
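// Sanity checks on the tiling hierarchy described above (added for
// illustration): 256 threads = 8 warps laid out 4x2, each warp covers a
// 32x64 sub-tile, and a 4x8 lane grid with 8x8 values per lane fills it.
static_assert((MS / 32) * (NS / 64) == 8, "8 warps per 128x128 block tile");
static_assert(4 * 8 * 8 * 8 == 32 * 64,
              "4x8 lanes times 8x8 per thread == one 32x64 warp tile");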
template <typename T>
__device__ __forceinline__ void
vscal_fma(Type4t<T> &dst_vec, const Type4t<T> &src_vec, const T &scale)
{
dst_vec.x += src_vec.x * scale;
dst_vec.y += src_vec.y * scale;
dst_vec.z += src_vec.z * scale;
dst_vec.w += src_vec.w * scale;
}
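// With nvcc's default --fmad=true, each multiply-add above contracts into a
// single FFMA instruction, so one vscal_fma call is 4 FMAs.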
template <typename T>
__device__ __forceinline__ void simd_axpby(Type4t<T> &dst_vec, T alpha,
const Type4t<T> &srca_vec, T beta,
const Type4t<T> &srcb_vec)
{
dst_vec.x = alpha * srca_vec.x + beta * srcb_vec.x;
dst_vec.y = alpha * srca_vec.y + beta * srcb_vec.y;
dst_vec.z = alpha * srca_vec.z + beta * srcb_vec.z;
dst_vec.w = alpha * srca_vec.w + beta * srcb_vec.w;
}
template <typename T>
__device__ __forceinline__ void vload(Type4t<T> &dst_vec, const T *addr)
{
dst_vec = *((Type4t<T> *)(addr));
}
template <typename T>
__device__ __forceinline__ void vstore(T *addr, const Type4t<T> &src_vec)
{
*((Type4t<T> *)(addr)) = src_vec;
}
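// vload/vstore assume the address is 16-byte aligned; a debug-time guard
// could look like this (illustrative helper, not called below):
template <typename T>
__device__ __forceinline__ bool is_vec4_aligned(const T *addr)
{
  return ((size_t)addr & (sizeof(Type4t<T>) - 1)) == 0;
}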
__device__ __forceinline__ void print(const float4 &vec)
{
printf("%f, %f, %f, %f\n", vec.x, vec.y, vec.z, vec.w);
}
/**
 * Quoted from CUTLASS for reference (a predicated 16-byte global load);
 * see the guarded-prefetch discussion in the main loop below.
 * template <typename AccessType>
struct global_load<AccessType,
16
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
uint4 &data = reinterpret_cast<uint4 &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" mov.b32 %0, %6;\n"
" mov.b32 %1, %7;\n"
" mov.b32 %2, %8;\n"
" mov.b32 %3, %9;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
" @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n"
#else
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
#endif
"}\n"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z),
"r"(data.w));
}
};
**/
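// A plain-C++ guarded variant of vload in the spirit of the CUTLASS snippet
// above (illustrative sketch; the kernel below sidesteps the issue by
// peeling the last K tile out of the main loop instead):
template <typename T>
__device__ __forceinline__ void vload_guarded(Type4t<T> &dst_vec,
                                              const T *addr, bool pred)
{
  if (pred)
    dst_vec = *((const Type4t<T> *)(addr));
}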
// Assumes M % 128 == 0, N % 128 == 0, and K % 8 == 0.
__global__ __launch_bounds__(256) void gemm_kernel(int M, int N, int K,
                                                   float alpha, const float *A,
                                                   const float *B, float beta,
                                                   float *C)
{
const int lda = K;
const int ldb = N;
const int ldc = N;
const int tx = threadIdx.x;
const int bx = blockIdx.x, by = blockIdx.y;
  // Row-major mapping: bx indexes the column (N) direction, by the row (M)
  // direction.
const int warp_id = tx >> 5;
const int lane_id = tx & 31;
  // 8 warps in total, arranged as 4 rows x 2 columns: each warp handles a
  // [32,64] tile, and 128/32 = 4, 128/64 = 2.
const int warp_col = warp_id & 1;
const int warp_row = warp_id >> 1;
  // Within a warp, lane_id is laid out as a [4,8] grid; each lane covers its
  // 8 columns in two 4-wide reads.
  const int col_w = lane_id & 7;
  const int row_w = lane_id >> 3;
  // Each thread owns an 8x8 C micro-tile: 8 rows by two 4-column groups.
  // C is row-major, so column indices are contiguous, and consecutive
  // columns should map to consecutive threads (a float4 is 16B, below the
  // 32B minimum L1/L2 transaction size). Columns are therefore distributed
  // cyclically while rows are not: a thread touches rows row_c..row_c+7 and
  // columns col_c..col_c+3 plus col_c+64..col_c+67.
  // Naively a warp's column offset would be warp_col << 6 (64 columns per
  // warp), but the cyclic split into two 32-wide groups 64 apart makes it
  // (warp_col << 5) + (col_w << 2); col_w << 2 because each thread reads or
  // writes 4 columns at a time.
  const int row_c = (warp_row << 5) + (row_w << 3); // rows row_c..row_c+7
  const int col_c = (warp_col << 5) + (col_w << 2); // cols col_c..+3, +64..+67
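  // Worked example (illustrative): tx = 37 -> warp_id = 1 (warp_row 0,
  // warp_col 1), lane_id = 5 (row_w 0, col_w 5) -> row_c = 0,
  // col_c = 32 + 20 = 52; this thread owns C rows 0-7 and columns 52-55
  // plus 116-119 of the block tile.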
  // row_a/col_a and row_b/col_b map threads for the global-to-shared copies.
  // Atile is [128,8]; each thread copies 4 values, so the 256 threads form a
  // [128,2] layout.
  const int col_a = (tx & 1) << 2;
  const int row_a = tx >> 1;
  // Btile is [8,128]; each thread copies 4 values, so the 256 threads form an
  // [8,32] layout.
  const int col_b = (tx & 31) << 2; // << 2 because each thread copies 4 values
  const int row_b = tx >> 5;
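  // Worked example (illustrative): tx = 37 -> row_a = 18, col_a = 4, so it
  // copies A[18][4..7]; row_b = 1, col_b = 20, so it copies B[1][20..23].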
  // Offset A, B, C to this block's tiles.
A += (by << 7) * lda;
B += (bx << 7);
C += (by << 7) * ldc + (bx << 7);
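  // e.g. block (bx = 2, by = 3) now points at A[384][0], B[0][256], and
  // C[384][256] (illustrative).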
  // Atile[128][8] is stored transposed so that the compute phase can read 4
  // consecutive A rows for a given k as one contiguous float4.
  // Btile[8][128]; the compute phase reads 1 row x 4 columns of B.
  // The leading [2] is the double-buffer dimension, which saves one
  // __syncthreads() per main-loop iteration.
  __shared__ float smema[2][8][128];
  __shared__ float smemb[2][8][128];
  float4 Av1[2], Av2[2], Bv1[2], Bv2[2], Cv[16], Cres[16];
  // Av1/Av2 together hold 8 A values; index [0]/[1] is the register double
  // buffer used for prefetching. Cv[16] and Cres[16] each hold 16*4 = 8*8
  // values.
  float4 pref_Av, pref_Bv; // staging registers for the global-memory prefetch
  // Staging registers (rather than loading straight into the compute
  // registers) keep the prefetch independent; the last K tile is peeled out
  // of the main loop below so the prefetch never reads past the end of A or B.
  memset(Cres, 0, sizeof(Cres));
  // Prefetch once before the main loop: a global-memory load into registers,
  // then a store into shared memory.
  vload(pref_Av, &(A(row_a, col_a)));
  vload(pref_Bv, &(B(row_b, col_b)));
  int buffer_switch = 0; // selects the active shared-memory buffer
  vstore(&(smemb[0][row_b][col_b]), pref_Bv);
  smema[0][col_a][row_a] = pref_Av.x;
  smema[0][col_a + 1][row_a] = pref_Av.y;
  smema[0][col_a + 2][row_a] = pref_Av.z;
  smema[0][col_a + 3][row_a] = pref_Av.w;
  // Sync after the shared-memory stores.
  __syncthreads();
  // Read Atile[row_c, k] for k = 0; thanks to the transpose, A's rows are
  // contiguous. Read Btile[k, col_c]; columns are cyclically distributed, so
  // this thread reads col_c..col_c+3 and col_c+64..col_c+67.
vload(Av1[0], &smema[0][0][row_c]);
vload(Av2[0], &smema[0][0][row_c + 4]);
vload(Bv1[0], &smemb[0][0][col_c]);
vload(Bv2[0], &smemb[0][0][col_c + 64]);
  for (int global_k = KS; global_k < K; global_k += KS)
  {
    // Main loop over K in steps of KS. global_k starts at KS because one
    // tile was already prefetched above, and the loop stops at global_k < K
    // so the prefetch stays in bounds. (Alternatively one could run
    // global_k <= K and wrap the address with global_k % K.) Either way the
    // loop runs one iteration short of K/KS, and the remaining tile is
    // computed after the loop.
    // Advance the A/B pointers and prefetch the next tile from global
    // memory. (Aside: does a statement block this large introduce branch
    // overhead?)
    A += KS;         // move right along A's K dimension
    B += (KS * ldb); // move down along B's K dimension
    vload(pref_Av, &(A(row_a, col_a)));
    // Each thread issues one vectorized global load per tile; the whole
    // block maps onto the [128,8] A tile and the [8,128] B tile.
    vload(pref_Bv, &(B(row_b, col_b)));
    // If this prefetch ran on the final iteration it would go out of
    // bounds; a predicated (@p) load, like CUTLASS's inline-PTX global load
    // quoted above, would be the more robust alternative.
int reg_switch = 0;
#pragma unroll
for (int inner_k_count = 0; inner_k_count < KS; inner_k_count++)
{
      int next_inner_k_count = (inner_k_count + 1) & 7; // (i + 1) mod 8
      // Prefetch the next k-slice from smem into registers for the next
      // sub-iteration's compute.
int next_reg = reg_switch ^ 1;
vload(Av1[next_reg], &smema[buffer_switch][next_inner_k_count][row_c]);
vload(Av2[next_reg],
&smema[buffer_switch][next_inner_k_count][row_c + 4]);
vload(Bv1[next_reg], &smemb[buffer_switch][next_inner_k_count][col_c]);
vload(Bv2[next_reg],
&smemb[buffer_switch][next_inner_k_count][col_c + 64]);
      // reg_switch selects the register buffer being computed on; next_reg
      // receives the prefetch. Row-major: within a C row the column index is
      // contiguous.
      // Cres[0..7]:  rows row_c..row_c+7, columns col_c..col_c+3
      vscal_fma(Cres[0], Bv1[reg_switch], Av1[reg_switch].x);
      vscal_fma(Cres[1], Bv1[reg_switch], Av1[reg_switch].y);
      vscal_fma(Cres[2], Bv1[reg_switch], Av1[reg_switch].z);
      vscal_fma(Cres[3], Bv1[reg_switch], Av1[reg_switch].w);
      vscal_fma(Cres[4], Bv1[reg_switch], Av2[reg_switch].x);
      vscal_fma(Cres[5], Bv1[reg_switch], Av2[reg_switch].y);
      vscal_fma(Cres[6], Bv1[reg_switch], Av2[reg_switch].z);
      vscal_fma(Cres[7], Bv1[reg_switch], Av2[reg_switch].w);
      // Cres[8..15]: rows row_c..row_c+7, columns col_c+64..col_c+67
      vscal_fma(Cres[8], Bv2[reg_switch], Av1[reg_switch].x);
      vscal_fma(Cres[9], Bv2[reg_switch], Av1[reg_switch].y);
      vscal_fma(Cres[10], Bv2[reg_switch], Av1[reg_switch].z);
      vscal_fma(Cres[11], Bv2[reg_switch], Av1[reg_switch].w);
      vscal_fma(Cres[12], Bv2[reg_switch], Av2[reg_switch].x);
      vscal_fma(Cres[13], Bv2[reg_switch], Av2[reg_switch].y);
      vscal_fma(Cres[14], Bv2[reg_switch], Av2[reg_switch].z);
      vscal_fma(Cres[15], Bv2[reg_switch], Av2[reg_switch].w);
      reg_switch ^= 1;
}
    buffer_switch ^= 1;
    // Switch shared-memory buffers, then store the prefetched global data
    // into the now-writable buffer.
    vstore(&(smemb[buffer_switch][row_b][col_b]), pref_Bv);
    smema[buffer_switch][col_a][row_a] = pref_Av.x;
    smema[buffer_switch][col_a + 1][row_a] = pref_Av.y;
    smema[buffer_switch][col_a + 2][row_a] = pref_Av.z;
    smema[buffer_switch][col_a + 3][row_a] = pref_Av.w;
    __syncthreads();
    // Preload the first k-slice from shared memory for the next iteration.
    vload(Av1[0], &smema[buffer_switch][0][row_c]);
    vload(Av2[0], &smema[buffer_switch][0][row_c + 4]);
    vload(Bv1[0], &smemb[buffer_switch][0][col_c]);
    vload(Bv2[0], &smemb[buffer_switch][0][col_c + 64]);
}
  // This version removes the branch from the main loop and instead peels the
  // last K tile into the manually duplicated sub-loop below; that change
  // improved performance by about 10%.
int reg_switch = 0;
#pragma unroll
for (int inner_k_count = 0; inner_k_count < KS; inner_k_count++)
{
    int next_inner_k_count = (inner_k_count + 1) & 7; // (i + 1) mod 8
    // Prefetch the next k-slice from smem into registers (on the final
    // sub-iteration this just reloads values that are never consumed).
    int next_reg = reg_switch ^ 1;
    vload(Av1[next_reg], &smema[buffer_switch][next_inner_k_count][row_c]);
    vload(Av2[next_reg],
          &smema[buffer_switch][next_inner_k_count][row_c + 4]);
    vload(Bv1[next_reg], &smemb[buffer_switch][next_inner_k_count][col_c]);
    vload(Bv2[next_reg],
          &smemb[buffer_switch][next_inner_k_count][col_c + 64]);
    // Same 8x8 micro-tile update as in the main loop.
    // Cres[0..7]:  rows row_c..row_c+7, columns col_c..col_c+3
    vscal_fma(Cres[0], Bv1[reg_switch], Av1[reg_switch].x);
    vscal_fma(Cres[1], Bv1[reg_switch], Av1[reg_switch].y);
    vscal_fma(Cres[2], Bv1[reg_switch], Av1[reg_switch].z);
    vscal_fma(Cres[3], Bv1[reg_switch], Av1[reg_switch].w);
    vscal_fma(Cres[4], Bv1[reg_switch], Av2[reg_switch].x);
    vscal_fma(Cres[5], Bv1[reg_switch], Av2[reg_switch].y);
    vscal_fma(Cres[6], Bv1[reg_switch], Av2[reg_switch].z);
    vscal_fma(Cres[7], Bv1[reg_switch], Av2[reg_switch].w);
    // Cres[8..15]: rows row_c..row_c+7, columns col_c+64..col_c+67
    vscal_fma(Cres[8], Bv2[reg_switch], Av1[reg_switch].x);
    vscal_fma(Cres[9], Bv2[reg_switch], Av1[reg_switch].y);
    vscal_fma(Cres[10], Bv2[reg_switch], Av1[reg_switch].z);
    vscal_fma(Cres[11], Bv2[reg_switch], Av1[reg_switch].w);
    vscal_fma(Cres[12], Bv2[reg_switch], Av2[reg_switch].x);
    vscal_fma(Cres[13], Bv2[reg_switch], Av2[reg_switch].y);
    vscal_fma(Cres[14], Bv2[reg_switch], Av2[reg_switch].z);
    vscal_fma(Cres[15], Bv2[reg_switch], Av2[reg_switch].w);
    reg_switch ^= 1;
}
  // The global_k main loop runs one iteration short so that its prefetch
  // never reads global memory out of bounds; the peeled sub-loop above
  // computed the final tile.
  // Load the C tile and combine it with the accumulated Cres.
vload(Cv[0], &C(row_c, col_c));
vload(Cv[1], &C(row_c + 1, col_c));
vload(Cv[2], &C(row_c + 2, col_c));
vload(Cv[3], &C(row_c + 3, col_c));
vload(Cv[4], &C(row_c + 4, col_c));
vload(Cv[5], &C(row_c + 5, col_c));
vload(Cv[6], &C(row_c + 6, col_c));
vload(Cv[7], &C(row_c + 7, col_c));
vload(Cv[8], &C(row_c, col_c + 64));
vload(Cv[9], &C(row_c + 1, col_c + 64));
vload(Cv[10], &C(row_c + 2, col_c + 64));
vload(Cv[11], &C(row_c + 3, col_c + 64));
vload(Cv[12], &C(row_c + 4, col_c + 64));
vload(Cv[13], &C(row_c + 5, col_c + 64));
vload(Cv[14], &C(row_c + 6, col_c + 64));
vload(Cv[15], &C(row_c + 7, col_c + 64));
simd_axpby(Cres[0], alpha, Cres[0], beta, Cv[0]);
simd_axpby(Cres[1], alpha, Cres[1], beta, Cv[1]);
simd_axpby(Cres[2], alpha, Cres[2], beta, Cv[2]);
simd_axpby(Cres[3], alpha, Cres[3], beta, Cv[3]);
simd_axpby(Cres[4], alpha, Cres[4], beta, Cv[4]);
simd_axpby(Cres[5], alpha, Cres[5], beta, Cv[5]);
simd_axpby(Cres[6], alpha, Cres[6], beta, Cv[6]);
simd_axpby(Cres[7], alpha, Cres[7], beta, Cv[7]);
simd_axpby(Cres[8], alpha, Cres[8], beta, Cv[8]);
simd_axpby(Cres[9], alpha, Cres[9], beta, Cv[9]);
simd_axpby(Cres[10], alpha, Cres[10], beta, Cv[10]);
simd_axpby(Cres[11], alpha, Cres[11], beta, Cv[11]);
simd_axpby(Cres[12], alpha, Cres[12], beta, Cv[12]);
simd_axpby(Cres[13], alpha, Cres[13], beta, Cv[13]);
simd_axpby(Cres[14], alpha, Cres[14], beta, Cv[14]);
simd_axpby(Cres[15], alpha, Cres[15], beta, Cv[15]);
vstore(&C(row_c, col_c), Cres[0]);
vstore(&C(row_c + 1, col_c), Cres[1]);
vstore(&C(row_c + 2, col_c), Cres[2]);
vstore(&C(row_c + 3, col_c), Cres[3]);
vstore(&C(row_c + 4, col_c), Cres[4]);
vstore(&C(row_c + 5, col_c), Cres[5]);
vstore(&C(row_c + 6, col_c), Cres[6]);
vstore(&C(row_c + 7, col_c), Cres[7]);
vstore(&C(row_c, col_c + 64), Cres[8]);
vstore(&C(row_c + 1, col_c + 64), Cres[9]);
vstore(&C(row_c + 2, col_c + 64), Cres[10]);
vstore(&C(row_c + 3, col_c + 64), Cres[11]);
vstore(&C(row_c + 4, col_c + 64), Cres[12]);
vstore(&C(row_c + 5, col_c + 64), Cres[13]);
vstore(&C(row_c + 6, col_c + 64), Cres[14]);
vstore(&C(row_c + 7, col_c + 64), Cres[15]);
}
// 3750.55 3591.11 3529.28 3460.77
template <typename T>
bool verify_res(size_t m, size_t n, const thrust::device_vector<T> &ref_data,
const thrust::device_vector<T> &res_data,
T abs_error = T(1e-2))
{
thrust::host_vector<T> href_data = ref_data;
thrust::host_vector<T> hres_data = res_data;
  T max_error = T(0); // stays 0 when all entries are within tolerance
int num_errors = 0;
for (size_t i = 0; i < m; i++)
{
for (size_t j = 0; j < n; j++)
{
auto tmp_error = std::abs(hres_data[i * n + j] - href_data[i * n + j]);
// std::cout<<tmp_error<<"\n";
if (tmp_error > abs_error)
{
num_errors++;
max_error = max_error < tmp_error ? tmp_error : max_error;
}
}
}
std::cout << "num_error: " << num_errors << " max error= " << max_error
<< " \n";
return num_errors == 0;
}
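// With K = 6144 and inputs drawn from [1,10), each C entry is on the order
// of 1e5, so a fixed absolute tolerance is strict for float; a relative
// check scales better across magnitudes (illustrative variant, not used
// below):
template <typename T>
bool within_rel_error(T ref, T res, T rel_error = T(1e-4))
{
  return std::abs(res - ref) <= rel_error * std::max(std::abs(ref), T(1));
}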
template <typename T>
void host_gemm(int M, int N, int K, std::vector<T> &A, std::vector<T> &B,
std::vector<T> &C, T alpha, T beta)
{
for (int m = 0; m < M; m++)
{
for (int n = 0; n < N; n++)
{
T accum = 0;
for (int k = 0; k < K; k++)
{
accum += A[m * K + k] * B[k * N + n];
}
C[m * N + n] = alpha * accum + beta * C[m * N + n];
}
}
}
// print an M-by-N array
template <typename T>
void print(size_t m, size_t n, thrust::device_vector<T> &d_data)
{
thrust::host_vector<T> h_data = d_data;
for (size_t i = 0; i < m; i++)
{
for (size_t j = 0; j < n; j++)
std::cout << std::setw(1) << h_data[i * n + j] << " ";
std::cout << "\n";
}
}
int main(int argc, char **argv)
{
const int M = 6144;
const int N = 6144;
const int K = 6144;
constexpr int Ms = 128;
constexpr int Ns = 128;
constexpr int Ks = 8;
using Element = float;
std::vector<Element> hA(M * K);
std::vector<Element> hB(K * N);
std::vector<Element> hC(M * N);
  std::random_device rd;  // seed source for the random engine
  std::mt19937 gen(rd()); // standard mersenne_twister_engine seeded with rd()
std::uniform_real_distribution<Element> dis(1, 10);
  std::generate(hA.begin(), hA.end(), [&gen, &dis]() { return dis(gen); });
  std::generate(hB.begin(), hB.end(), [&gen, &dis]() { return dis(gen); });
thrust::device_vector<Element> dA = hA;
thrust::device_vector<Element> dB = hB;
thrust::device_vector<Element> dC(M * N);
thrust::device_vector<Element> drefC(M * N);
Element *dA_ptr = thrust::raw_pointer_cast(dA.data());
Element *dB_ptr = thrust::raw_pointer_cast(dB.data());
Element *dC_ptr = thrust::raw_pointer_cast(dC.data());
Element *dCref_ptr = thrust::raw_pointer_cast(drefC.data());
cublasHandle_t handle;
cublasSafeCall(cublasCreate(&handle));
float alpha = 1.;
float beta = 0.;
cublasSafeCall(cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
&alpha, dB_ptr, N, dA_ptr, K, &beta, dCref_ptr,
N));
  const dim3 block(256);
  // grid.x walks the N (column) direction and grid.y the M (row) direction,
  // matching bx/by inside the kernel.
  const dim3 grid((N + 127) / 128, (M + 127) / 128);
  // Launched twice: presumably a warm-up plus a measured run (beta == 0, so
  // the result is simply recomputed).
  gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);
  gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);
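  // A minimal timing sketch (not in the original) to reproduce throughput
  // numbers like the ones noted after the kernel; assumes the two launches
  // above already served as warm-up.
  cudaEvent_t ev_start, ev_stop;
  cudaEventCreate(&ev_start);
  cudaEventCreate(&ev_stop);
  cudaEventRecord(ev_start);
  gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);
  cudaEventRecord(ev_stop);
  cudaEventSynchronize(ev_stop);
  float elapsed_ms = 0.f;
  cudaEventElapsedTime(&elapsed_ms, ev_start, ev_stop);
  // A GEMM performs 2*M*N*K flops.
  std::cout << "gemm_kernel: "
            << (2.0 * M * N * K) / (elapsed_ms * 1e-3) * 1e-9 << " GFLOP/s\n";
  cudaEventDestroy(ev_start);
  cudaEventDestroy(ev_stop);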
std::cout << cudaGetErrorString(cudaGetLastError()) << "\n";
verify_res(M, N, dC, drefC);
// host_gemm(M, N, K, hA, hB, hC, alpha, beta);
// verify_res(M, N, thrust::device_vector<Element>(hC), drefC);
// std::cout << "hc verify dc \n";
// verify_res(M, N, thrust::device_vector<Element>(hC), dC);
// print(M, N, dC);
// print(M, N, drefC);
}
//
//
// nvcc -arch=sm_75 -O3 ./gemm.cu -o gemmtest -lcublas