做了一个简单的cuda项目,故记录下来。
头文件和宏定义 1 2 3 4 5 6 7 8 9 10 11 12 13 #include <algorithm> #include <cmath> #include <vector> #include <cuda_fp16.h> #include <cuda_runtime.h> #include <iostream> #include <cassert> #include "../tester/utils.h" #define BLOCK_SIZE 256 constexpr int Br = 16 ; constexpr int Bc = 16 ;
1. Warp 级归约操作 Warp Reduce Sum 在 32 个线程内快速求和(不需要 Shared Memory)
1 2 3 4 5 6 7 8 9 10 11 12 template <typename T>__device__ __forceinline__ T warpReduceSum (T val) { for (int offset = warpSize / 2 ; offset > 0 ; offset >>= 1 ) { val += __shfl_down_sync(0xffffffff , val, offset); } return val; }
Warp Reduce Max 辅助函数:Warp 内求最大值
1 2 3 4 5 6 7 8 template <typename T>__device__ __forceinline__ T warpReduceMax (T val) { for (int offset = 16 ; offset > 0 ; offset >>= 1 ) { T temp = __shfl_down_sync(0xffffffff , val, offset); if (temp > val) val = temp; } return val; }
2. Block 级归约操作 blockDim.x定义在dim3 block(BLOCK_SIZE);
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 template <typename T>__device__ __forceinline__ T blockReduceSum (T val) { if (blockDim.x <= warpSize) { return warpReduceSum (val); } static __shared__ T shared[32 ]; int lane = threadIdx.x % warpSize; int wid = threadIdx.x / warpSize; val = warpReduceSum (val); if (lane == 0 ) { shared[wid] = val; } __syncthreads(); val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0 ; if (wid == 0 ) { val = warpReduceSum (val); } return val; }
3. 矩阵迹的计算 迹的核函数 注意核函数均定义为void
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 template <typename T>__global__ void traceKernel (const T* __restrict__ input, size_t rows, size_t cols, size_t n, T* out) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; size_t stride = blockDim.x * gridDim.x; T local_sum = 0 ; for (size_t i = idx; i < n; i += stride) { local_sum += input[i * cols + i]; } local_sum = blockReduceSum (local_sum); if (threadIdx.x == 0 ) { atomicAdd (out, local_sum); } }
Host 端迹的计算函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 template <typename T>T trace (const std::vector<T>& h_input, size_t rows, size_t cols) { if (rows == 0 || cols == 0 ) { return T (0 ); } size_t n = std::min (rows, cols); size_t total_elems = rows * cols; T* d_input = nullptr ; T* d_out = nullptr ; RUNTIME_CHECK (cudaMalloc (&d_input, total_elems * sizeof (T))); RUNTIME_CHECK (cudaMalloc (&d_out, sizeof (T))); RUNTIME_CHECK (cudaMemcpy (d_input, h_input.data (), total_elems * sizeof (T), cudaMemcpyHostToDevice)); RUNTIME_CHECK (cudaMemset (d_out, 0 , sizeof (T))); dim3 block (BLOCK_SIZE) ; dim3 grid ((n + BLOCK_SIZE - 1 ) / BLOCK_SIZE) ; traceKernel<T><<<grid, block>>>(d_input, rows, cols, n, d_out); RUNTIME_CHECK (cudaGetLastError ()); RUNTIME_CHECK (cudaDeviceSynchronize ()); T h_out = T (0 ); RUNTIME_CHECK (cudaMemcpy (&h_out, d_out, sizeof (T), cudaMemcpyDeviceToHost)); RUNTIME_CHECK (cudaFree (d_input)); RUNTIME_CHECK (cudaFree (d_out)); return h_out; }
4. 注意力机制 - 类型转换 float 类型转换 1 2 3 4 5 6 7 8 9 10 11 12 template <typename T>__device__ __forceinline__ float to_float (T v) ;template <>__device__ __forceinline__ float to_float <float >(float v) { return v; } template <>__device__ __forceinline__ float to_float <half>(half v) { return __half2float(v); }
反向类型转换 1 2 3 4 5 6 7 8 9 10 11 12 13 template <typename T>__device__ __forceinline__ T from_float (float v) ;template <> __device__ __forceinline__ float from_float <float >(float v) { return v; } template <>__device__ __forceinline__ half from_float <half>(float v) { return __float2half(v); }
forceinline
强制内联: • 减少函数调用开销 • 对 warp-level 操作重要 ifelse逻辑会runtime branching,性能差
5. 朴素注意力实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 template <typename T>__global__ void native_attention_kernel (const T* q, const T* k, const T* v, T* o, int batch_size, int target_seq_len, int src_seq_len, int query_heads, int kv_heads, int head_dim, bool is_causal) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t o_elems = batch_size * target_seq_len * query_heads * head_dim; if (idx >= o_elems) { return ; } int d = static_cast <int >(idx % head_dim); size_t tmp = idx / head_dim; int qh = static_cast <int >(tmp % query_heads); tmp /= query_heads; int t = static_cast <int >(tmp % target_seq_len); int b = static_cast <int >(tmp / target_seq_len); int kv_h = (qh * kv_heads) / query_heads; const float scale = 1.0f / sqrtf (static_cast <float >(head_dim)); float max_score = -INFINITY; for (int sk = 0 ; sk < src_seq_len; ++sk) { if (is_causal && sk > t) { continue ; } const T* q_ptr = q + (((b * target_seq_len + t) * query_heads + qh) * head_dim); const T* k_ptr = k + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim); float dot = 0.0f ; for (int i = 0 ; i < head_dim; ++i) { dot += to_float (q_ptr[i]) * to_float (k_ptr[i]); } float score = dot * scale; if (score > max_score) { max_score = score; } } float denom = 0.0f ; for (int sk = 0 ; sk < src_seq_len; ++sk) { if (is_causal && sk > t) { continue ; } const T* q_ptr = q + (((b * target_seq_len + t) * query_heads + qh) * head_dim); const T* k_ptr = k + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim); float dot = 0.0f ; for (int i = 0 ; i < head_dim; ++i) { dot += to_float (q_ptr[i]) * to_float (k_ptr[i]); } float score = dot * scale; denom += expf (score - max_score); } if (denom == 0.0f ) { o[idx] = from_float <T>(0.0f ); return ; } float out_val = 0.0f ; for (int sk = 0 ; sk < src_seq_len; ++sk) { if (is_causal && sk > t) { continue ; } const T* q_ptr = q + (((b * target_seq_len + t) * query_heads + qh) * head_dim); const T* k_ptr = k + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim); const T* v_ptr = v + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim); float dot = 0.0f ; for (int i = 0 ; i < head_dim; ++i) { dot += to_float (q_ptr[i]) * to_float (k_ptr[i]); } float score = dot * scale; float w = expf (score - max_score) / denom; out_val += w * to_float (v_ptr[d]); } o[idx] = from_float <T>(out_val); }
6. Flash Attention V1 - 优化版本 核心改进说明 Flash Attention V2升级到支持分块(Tiling)计算 :
2D线程块 :从1D改为(Bc, Br),充分利用硬件
Tile操作 :分别载入和计算Q、K、V的tile,减少全局内存访问
Bank Conflict避免 :使用smem_stride参数在共享内存中填充,避免冲突
Online Softmax :在计算过程中动态更新max和denominator,数值稳定
Warp级规约 :使用shuffle指令进行高效的跨线程规约
优化后的核函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 template <typename T>__global__ void flash_attention_v1_kernel (const T* __restrict__ Q, const T* __restrict__ K, const T* __restrict__ V, T* __restrict__ O, int batch_size, int target_seq_len, int src_seq_len, int query_heads, int kv_heads, int head_dim, int smem_stride, bool is_causal, float scale) { int tx = threadIdx.x; int ty = threadIdx.y; int batch_idx = blockIdx.z; int head_idx = blockIdx.y; int q_block_idx = blockIdx.x; int q_start_idx = q_block_idx * Br; int q_len_local = min (Br, target_seq_len - q_start_idx); int kv_head_idx = (head_idx * kv_heads) / query_heads; extern __shared__ float smem[]; float * s_Q = smem; float * s_K = s_Q + Br * smem_stride; float * s_V = s_K + Bc * smem_stride; float * s_O = s_V + Bc * smem_stride; float * s_m = s_O + Br * smem_stride; float * s_l = s_m + Br; int tid = threadIdx.y * blockDim.x + threadIdx.x; int total_threads = blockDim.x * blockDim.y; for (int i = tid; i < Br * head_dim; i += total_threads) { int r = i / head_dim; int c = i % head_dim; int global_q = q_start_idx + r; if (r < q_len_local && global_q < target_seq_len) { size_t q_index = ((static_cast <size_t >(batch_idx) * target_seq_len + global_q) * query_heads + head_idx) * head_dim + c; s_Q[r * smem_stride + c] = to_float (Q[q_index]); } else { s_Q[r * smem_stride + c] = 0.0f ; } s_O[r * smem_stride + c] = 0.0f ; } if (tx == 0 && ty < Br) { s_m[ty] = -1e20 f; s_l[ty] = 0.0f ; } __syncthreads(); for (int j_base = 0 ; j_base < src_seq_len; j_base += Bc) { int kv_len_local = min (Bc, src_seq_len - j_base); for (int i = tid; i < Bc * head_dim; i += total_threads) { int r = i / head_dim; int c = i % head_dim; int global_k = j_base + r; if (r < kv_len_local && global_k < src_seq_len) { size_t k_index = ((static_cast <size_t >(batch_idx) * src_seq_len + global_k) * kv_heads + kv_head_idx) * head_dim + c; s_K[r * smem_stride + c] = to_float (K[k_index]); s_V[r * smem_stride + c] = to_float (V[k_index]); } else { s_K[r * smem_stride + c] = 0.0f ; s_V[r * smem_stride + c] = 0.0f ; } } __syncthreads(); if (ty < q_len_local) { float score = 0.0f ; bool valid_k = (tx < kv_len_local); int global_q_idx = q_start_idx + ty; int global_k_idx = j_base + tx; if (is_causal && global_k_idx > global_q_idx) { valid_k = false ; } if (valid_k) { for (int d = 0 ; d < head_dim; ++d) { score += s_Q[ty * smem_stride + d] * s_K[tx * smem_stride + d]; } score *= scale; } else { score = -INFINITY; } unsigned mask = __activemask(); float m_local = score; #pragma unroll for (int offset = 8 ; offset > 0 ; offset /= 2 ) { m_local = fmaxf (m_local, __shfl_xor_sync(mask, m_local, offset)); } float p = (score == -INFINITY) ? 0.0f : expf (score - m_local); float l_local = p; #pragma unroll for (int offset = 8 ; offset > 0 ; offset /= 2 ) { l_local += __shfl_xor_sync(mask, l_local, offset); } float m_prev = s_m[ty]; float l_prev = s_l[ty]; float m_new = fmaxf (m_prev, m_local); float scale_prev = expf (m_prev - m_new); float scale_curr = expf (m_local - m_new); float l_new = l_prev * scale_prev + l_local * scale_curr; if (tx == 0 ) { s_m[ty] = m_new; s_l[ty] = l_new; } for (int d = 0 ; d < head_dim; ++d) { float v_val = valid_k ? s_V[tx * smem_stride + d] : 0.0f ; float pd = p * v_val; #pragma unroll for (int offset = 8 ; offset > 0 ; offset /= 2 ) { pd += __shfl_xor_sync(mask, pd, offset); } if (tx == 0 ) { float o_val = s_O[ty * smem_stride + d]; s_O[ty * smem_stride + d] = o_val * scale_prev + pd * scale_curr; } } } __syncthreads(); } if (ty < q_len_local) { float denom = s_l[ty]; float inv_l = (denom > 0.0f ) ? (1.0f / denom) : 0.0f ; for (int d = tx; d < head_dim; d += Bc) { float val = s_O[ty * smem_stride + d] * inv_l; int global_q = q_start_idx + ty; size_t o_index = ((static_cast <size_t >(batch_idx) * target_seq_len + global_q) * query_heads + head_idx) * head_dim + d; O[o_index] = from_float <T>(val); } } }
Host 端 Flash Attention 函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 template <typename T>void flashAttention (const std::vector<T>& h_q, const std::vector<T>& h_k, const std::vector<T>& h_v, std::vector<T>& h_o, int batch_size, int target_seq_len, int src_seq_len, int query_heads, int kv_heads, int head_dim, bool is_causal) { if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 || query_heads <= 0 || kv_heads <= 0 || head_dim <= 0 ) { h_o.clear (); return ; } const size_t q_elems = batch_size * target_seq_len * query_heads * head_dim; const size_t k_elems = batch_size * src_seq_len * kv_heads * head_dim; const size_t v_elems = k_elems; const size_t o_elems = q_elems; h_o.resize (o_elems); T* d_q = nullptr ; T* d_k = nullptr ; T* d_v = nullptr ; T* d_o = nullptr ; RUNTIME_CHECK (cudaMalloc (&d_q, q_elems * sizeof (T))); RUNTIME_CHECK (cudaMalloc (&d_k, k_elems * sizeof (T))); RUNTIME_CHECK (cudaMalloc (&d_v, v_elems * sizeof (T))); RUNTIME_CHECK (cudaMalloc (&d_o, o_elems * sizeof (T))); RUNTIME_CHECK (cudaMemcpy (d_q, h_q.data (), q_elems * sizeof (T), cudaMemcpyHostToDevice)); RUNTIME_CHECK (cudaMemcpy (d_k, h_k.data (), k_elems * sizeof (T), cudaMemcpyHostToDevice)); RUNTIME_CHECK (cudaMemcpy (d_v, h_v.data (), v_elems * sizeof (T), cudaMemcpyHostToDevice)); RUNTIME_CHECK (cudaMemset (d_o, 0 , o_elems * sizeof (T))); dim3 block (Bc, Br) ; int grid_x = (target_seq_len + Br - 1 ) / Br; dim3 grid (grid_x, query_heads, batch_size) ; int smem_stride = head_dim + 4 ; size_t shared_bytes = (Br * smem_stride + Bc * smem_stride + Bc * smem_stride + Br * smem_stride + Br + Br) * sizeof (float ); float scale = 1.0f / sqrtf (static_cast <float >(head_dim)); flash_attention_v1_kernel<T><<<grid, block, shared_bytes>>>( d_q, d_k, d_v, d_o, batch_size, target_seq_len, src_seq_len, query_heads, kv_heads, head_dim, smem_stride, is_causal, scale); RUNTIME_CHECK (cudaGetLastError ()); RUNTIME_CHECK (cudaDeviceSynchronize ()); RUNTIME_CHECK (cudaMemcpy (h_o.data (), d_o, o_elems * sizeof (T), cudaMemcpyDeviceToHost)); RUNTIME_CHECK (cudaFree (d_q)); RUNTIME_CHECK (cudaFree (d_k)); RUNTIME_CHECK (cudaFree (d_v)); RUNTIME_CHECK (cudaFree (d_o)); }
更新日志 Flash Attention 优化升级 主要改进 :
Tiling 策略
Query 按 Br×head_dim 分块(Br=16)
Key/Value 按 Bc×head_dim 分块(Bc=16)
减少全局内存访问,增加局部数据重用
线程块布局
从 1D 改为 2D:(Bc, Br) 形状
x 维处理 K/V 的列(head_dim)
y 维处理 Q 的行(Br)
Bank Conflict 消除
共享内存使用 smem_stride = head_dim + 4
避免多个线程同时访问同一 bank
Online Softmax 实现
每个 tile 进行一次 softmax 更新
使用归一化公式:$m_{new} = \max(m_{prev}, m_{local})$
$l_{new} = l_{prev} \cdot e^{m_{prev} - m_{new}} + l_{local} \cdot e^{m_{local} - m_{new}}$
Warp 级规约
使用 __shfl_xor_sync 替代 shared memory 规约
更高效的跨 lane 通信
Grid 配置
Grid: (grid_x, query_heads, batch_size)
blockIdx.x 对应 Q block,blockIdx.y 对应 head,blockIdx.z 对应 batch
7. 显式模板实例化 1 2 3 4 5 6 7 8 9 10 template int trace <int >(const std::vector<int >&, size_t , size_t );template float trace <float >(const std::vector<float >&, size_t , size_t );template void flashAttention <float >(const std::vector<float >&, const std::vector<float >&, const std::vector<float >&, std::vector<float >&, int , int , int , int , int , int , bool ); template void flashAttention <half>(const std::vector<half>&, const std::vector<half>&, const std::vector<half>&, std::vector<half>&, int , int , int , int , int , int , bool );
8. 极致性能优化:与第一对比 在参考了其他优秀的开源实现(如 forlearn/Learning-CUDA/src/kernels.cu)后,我发现我的 V1 版本虽然实现了基本的分块和 Shared Memory 优化,但在工程细节和极致性能压榨上还有很大的提升空间。以下是对该高性能版本中 5 个核心优化点的深度剖析,通过“我的写法”与“人家的写法”的直观对比,总结其性能差异。
8.1 内存分配与释放的极致优化:合并与异步 在原版实现中,我们为 Q、K、V、O 分别调用了四次 cudaMalloc 和 cudaFree。由于 GPU 内存分配是同步且开销极大的操作,这会严重拖慢整体执行速度。
我的写法(多次同步分配与释放) :
1 2 3 4 5 6 7 8 9 10 T* d_q, *d_k, *d_v, *d_o; RUNTIME_CHECK (cudaMalloc (&d_q, q_elems * sizeof (T)));RUNTIME_CHECK (cudaMalloc (&d_k, k_elems * sizeof (T)));RUNTIME_CHECK (cudaMalloc (&d_v, v_elems * sizeof (T)));RUNTIME_CHECK (cudaMalloc (&d_o, o_elems * sizeof (T)));RUNTIME_CHECK (cudaFree (d_q));RUNTIME_CHECK (cudaFree (d_k));RUNTIME_CHECK (cudaFree (d_v));RUNTIME_CHECK (cudaFree (d_o));
人家的写法(单次异步分配,切片使用) : 计算出总共需要的内存大小,一次性分配一整块连续的 Device 内存,然后通过指针偏移(切片)给 Q、K、V、O 使用。同时使用 cudaMallocAsync 和 cudaFreeAsync,并绑定到具有高优先级的独立非阻塞 Stream 上,减少与默认 Stream 的同步开销。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 cudaStreamCreateWithPriority (&stream2, cudaStreamNonBlocking, -1 );const size_t total_bytes = size_bytes_q + size_bytes_k + size_bytes_v + size_bytes_o;T* d_all = nullptr ; RUNTIME_CHECK (cudaMallocAsync (&d_all, total_bytes, stream2));T *d_q = d_all; T *d_k = d_q + h_q.size (); T *d_v = d_k + h_k.size (); T *d_o = d_v + h_v.size (); RUNTIME_CHECK (cudaMemcpyAsync (d_q, h_q.data (), size_bytes_q, cudaMemcpyHostToDevice, stream2));RUNTIME_CHECK (cudaFreeAsync (d_all, stream2));
8.2 算法层面的 I/O 优化:仅加载有效数据 在计算矩阵的迹(Trace)时,原版代码将整个 $N \times N$ 的矩阵拷贝到了 GPU,然后在 Kernel 中通过 i * cols + i 提取对角线元素。这导致了 $O(N^2)$ 的数据传输,而实际参与计算的只有 $O(N)$ 的数据。
我的写法(全量拷贝,跨步访问) :
1 2 3 4 5 6 7 RUNTIME_CHECK (cudaMemcpy (d_input, h_input.data (), total_elems * sizeof (T), cudaMemcpyHostToDevice));for (size_t i = idx; i < n; i += stride) { local_sum += input[i * cols + i]; }
人家的写法(按需拷贝,连续访问) : 在 Host 端预先提取对角线元素,仅将有效数据传输至 Device 端。这不仅大幅降低了 PCIe 带宽压力,还使得 Kernel 中的内存访问变成了完全连续的合并访问(Coalesced Memory Access)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 template <typename T>std::vector<T> extract_diag (const std::vector<T> & h_input, size_t rows, size_t cols) { size_t n = std::min (rows, cols); std::vector<T> diag (n) ; for (size_t i = 0 ; i < n; ++i){ diag[i] = h_input[i * cols + i]; } return diag; } template <typename T>__global__ void trace_calc (T* d_trace, const T* d_diag, size_t n) { for (size_t i = idx; i < n; i += stride){ sum += d_diag[i]; } }
8.3 循环展开与分支预测优化 在 GPU 编程中,分支(Branching)和循环控制开销会破坏指令流水线。
我的写法(常规循环与内层条件判断) :
1 2 3 4 5 6 7 8 9 for (int offset = warpSize / 2 ; offset > 0 ; offset >>= 1 ) { val += __shfl_down_sync(0xffffffff , val, offset); } if (is_causal && global_k_idx > global_q_idx) { valid_k = false ; }
人家的写法(强制展开与分支前置) : 对已知迭代次数的短循环进行强制展开; #pragam unroll 指令告诉编译器将循环完全展开,消除循环控制开销和分支预测的影响。 编译后直接变成: val += __shfl_down_sync(…, 16); val += __shfl_down_sync(…, 8); val += __shfl_down_sync(…, 4); val += __shfl_down_sync(…, 2); val += __shfl_down_sync(…, 1);
将复杂的 Causal Mask 判断逻辑提前计算并转换为布尔标志,避免在内层循环中反复进行复杂的条件判断。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 template <typename T>__device__ T warp_reduce_sum (T val) {#pragma unroll for (int offset = 16 ; offset > 0 ; offset >>= 1 ){ val += __shfl_down_sync(0xffffffff , val, offset); } return val; } bool is_compute = true ; if (is_causal) { }
8.4 Shared Memory 资源复用 Shared Memory(SMEM)是 SM 中极其宝贵且有限的资源。在 Flash Attention 中,我们需要存储 $S = Q \times K^T$ 的结果,随后计算 $P = \text{softmax}(S)$。
我的写法(独立分配) : 为不同阶段的变量分配独立的 Shared Memory,如果中间矩阵较多,极易导致 SMEM 耗尽,降低 Thread Block 的 Occupancy(占用率)。
1 2 3 4 float * s_Q = smem; float * s_K = s_Q + Br * smem_stride; float * s_V = s_K + Bc * smem_stride;
人家的写法(生命周期错开的变量直接复用) : 由于 $S$ 矩阵在计算出 $P$ 之后就不再被需要,直接让 $P$ 覆盖 $S$ 的内存空间。通过指针复用,节省了一半的中间矩阵 SMEM 占用。
1 2 3 4 5 6 7 8 9 extern __shared__ char shared_mem[];char * ptr = shared_mem; double * SP = reinterpret_cast <double *>(ptr); ptr += Br * Bc * sizeof (double ); #define SP_AT(y, x) SP[y * Bc + x]
8.5 关键步骤的精度保护 (Double Precision) 在 Flash Attention 的 Online Softmax 计算中,涉及到指数运算 exp(S - m) 和累加求和。如果使用 float 甚至 half,在序列较长或数值差异较大时,极易发生精度溢出或下溢(Underflow)。
我的写法(全程单精度) :
1 2 3 4 5 const float scale = 1.0f / sqrtf (static_cast <float >(head_dim));float score = 0.0f ;float p = (score == -INFINITY) ? 0.0f : expf (score - m_local);
人家的写法(关键路径双精度护航) : 在计算 $S$ 矩阵、$P$ 矩阵以及缩放因子 scale_factor 时,强制提升至 double 精度进行中间计算,最后再向下转换为目标类型。这在不显著增加计算时间的前提下,极大地保护了数值稳定性。
1 2 3 4 5 6 7 8 9 const double scale_factor = 1.0 / sqrt (double (head_dim));double * SP = reinterpret_cast <double *>(ptr);
8.6 Shared Memory 访存模式优化:K 矩阵转置消除 Bank Conflict 在计算 $S = Q \times K^T$ 时,需要从 Shared Memory 中读取 Q 和 K 的数据。
我的写法(依赖 Padding 缓解冲突) :
1 2 3 4 5 6 float * s_K = s_Q + Br * smem_stride; score += s_Q[ty * smem_stride + d] * s_K[tx * smem_stride + d];
人家的写法(加载时转置,实现连续访存) : 在将 K 矩阵从 Global Memory 加载到 Shared Memory 时,直接将其转置存储为 K_T_sm[head_dim][Bc]。这样在计算点积时,同一 Warp 内的线程(tx 不同)访问的是内存中连续的地址,从根本上消除了 Bank Conflict,达到了极致的访存效率。
1 2 3 4 5 6 7 8 float * K_T_sm = reinterpret_cast <float *>(ptr); #define K_T_sm_AT(y, x) K_T_sm[y * Bc + x]
8.7 跨平台兼容与软硬件协同:软件模拟双精度 (FP32x2) 在某些国产 GPU 平台(如 Iluvatar)或消费级显卡上,硬件原生的双精度(FP64)计算单元可能非常少,导致使用 double 会造成严重的性能瓶颈。
我的写法(无视硬件差异) : 没有考虑不同硬件平台的特性,直接使用标准的 float 或 double,在 FP64 性能孱弱的显卡上会遭遇断崖式掉速。
人家的写法(软件模拟双精度) : 针对特定平台(#ifdef PLATFORM_ILUVATAR),实现了一个 myDouble 类。利用两个 float(_hi 和 _lo)以及 FMA 指令(__fmaf_rn)在软件层面模拟双精度计算。这在保证 Online Softmax 数值稳定性的同时,完全避开了硬件 FP64 性能孱弱的瓶颈,是极其硬核的极致压榨。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 #ifdef PLATFORM_ILUVATAR class myDouble {private : float _hi; float _lo; public : __host__ __device__ myDouble operator *(const float op) const { float p_hi = _hi * op; float p_lo = __fmaf_rn(_hi, op, -p_hi); p_lo += (_lo * op); float final_hi = p_hi + p_lo; float final_lo = p_lo - (final_hi - p_hi); return myDouble (final_hi, final_lo); } }; #endif
8.8 常量计算的查表法 (LUT) 优化 在计算 Attention Score 的缩放因子时,需要用到 1.0 / sqrt(head_dim)。
我的写法(运行时计算) : 在 Kernel 启动前或 Kernel 内部,调用浮点开方和除法指令进行计算。
1 const float scale = 1.0f / sqrtf (static_cast <float >(head_dim));
人家的写法(编译期/设备端查表) : 针对常见的 head_dim(如 32, 64),直接硬编码预计算好的高精度双浮点结果(mylut),用查表法(Lookup Table)替代了昂贵的浮点开方和除法指令。
1 2 3 4 5 6 7 8 9 10 11 __device__ myDouble mylut (int head_dim) { switch (head_dim){ case 32 :return myDouble (0.176776695f , 2.96636886e-10 f); case 64 :return myDouble (0.125f , 0.0f ); default :return myDouble (0.0f , 0.0f ); } } const myDouble scale_factor = mylut (head_dim);
8.9 国产平台适配:C++ 标准与编译器特性的妥协 在将代码移植到不同的国产 GPU 平台(如 Moore 摩尔线程)时,编译器的支持程度往往参差不齐。
我的写法(过度依赖现代 C++ 特性) : 在实现 myexp 函数时,使用了 C++17 的 if constexpr 来进行编译期类型分支判断。这在 NVCC 或较新的编译器上运行良好,但在某些仅支持 C++11 的国产平台编译器(如 Moore 平台的 musa 编译器)上会直接导致编译失败。
1 2 3 4 5 6 7 template <typename T>__device__ T myexp (T x) { if constexpr (std::is_same<T, __half>::value) { } }
人家的写法(回归基础,利用模板特化) : 为了保证最大的跨平台兼容性,放弃了 if constexpr,转而使用最基础的 C++98/11 模板特化(Template Specialization)或函数重载来实现不同类型的分支逻辑。
1 2 3 4 5 6 7 8 9 10 11 12 __device__ __forceinline__ float myexp (float x) { return expf (x); } __device__ __forceinline__ double myexp (double x) { return exp (x); } __device__ __forceinline__ half myexp (half x) { return __float2half(expf (__half2float(x))); }
8.10 国产平台适配:从 CUDA 到 MUSA (Moore) 与 MACA (沐曦) 的迁移指南 在将 CUDA 代码迁移到国产 GPU 平台(如摩尔线程的 MUSA 和沐曦的 MACA)时,除了上述提到的 C++ 标准兼容性问题,最核心的工作是 API 的替换与硬件特性的适配。
1. API 命名空间的无缝替换 国产平台通常提供了与 CUDA 高度兼容的 API 接口,迁移的第一步是进行全局的命名空间替换。
**CUDA 到 MUSA (摩尔线程 Moore)**:
头文件:#include <cuda_fp16.h> $\rightarrow$ #include <musa_fp16.h>
内存管理:cudaMalloc $\rightarrow$ musaMalloc,cudaFree $\rightarrow$ musaFree
数据拷贝:cudaMemcpy $\rightarrow$ musaMemcpy,cudaMemcpyHostToDevice $\rightarrow$ musaMemcpyHostToDevice
错误检查:cudaGetLastError $\rightarrow$ musaGetLastError
设备同步:cudaDeviceSynchronize $\rightarrow$ musaDeviceSynchronize
**CUDA 到 MACA (沐曦)**:
头文件:#include <cuda_fp16.h> $\rightarrow$ #include <common/maca_fp16.h>
内存管理:cudaMalloc $\rightarrow$ mcMalloc,cudaFree $\rightarrow$ mcFree
数据拷贝:cudaMemcpy $\rightarrow$ mcMemcpy,cudaMemcpyHostToDevice $\rightarrow$ mcMemcpyHostToDevice
错误检查:cudaGetLastError $\rightarrow$ mcGetLastError
设备同步:cudaDeviceSynchronize $\rightarrow$ mcDeviceSynchronize
2. 硬件架构参数的微调 不同平台的 SM(Streaming Multiprocessor)架构和资源限制不同,需要针对性地调整 Kernel 的启动参数。
Warp Size 差异 :
NVIDIA (CUDA) 和 Moore (MUSA) 的 Warp Size 通常为 32 。
沐曦 (MACA) 的 Warp Size 可能为 64 (如 mcDeviceProp_t.warpSize 所示)。
适配建议 :在进行 Warp 级规约(如 warp_reduce_sum)时,循环的初始 offset 需要根据平台的实际 Warp Size 进行调整(如从 16 改为 32)。
Shared Memory 限制 :
NVIDIA RTX 5090:每 Block 约 48KB。
Moore:每 Block 可达 192KB。
沐曦:每 Block 约 64KB。
适配建议 :在设计 Tiling 策略(如 Br 和 Bc 的大小)时,需确保分配的 smem_size 不超过目标平台的硬件上限。
3. 异步操作的支持度 在某些国产平台的早期驱动或特定型号上,异步 API(如 cudaMallocAsync、cudaFreeAsync)可能未被完全支持或存在性能问题。
适配建议 :在迁移初期,建议先回退到同步 API(如 mcMalloc、musaMalloc),确保功能正确性后,再逐步尝试引入 Stream 和 Async API 进行性能调优。
8.11 总结 通过对比分析这份高性能代码,我们可以得出编写极致 CUDA 算子的几个核心方法论:
Host 端能做的绝不交给 Device 端 (如提取对角线)。
API 调用的开销不容忽视 (合并 Malloc,使用 Async 和 Stream)。
寄存器 > Shared Memory > Global Memory (数据一旦加载到 SMEM,就要尽可能榨干其复用价值,如 S/P 矩阵复用)。
指令级优化 (#pragma unroll 消除循环开销,位运算替代乘除法)。
在性能与精度之间寻找平衡 (关键路径使用 double 护航)。
Shared Memory 布局的艺术 (通过转置消除 Bank Conflict,远比单纯加 Padding 高效)。
软硬件协同的极致压榨 (在 FP64 孱弱的平台上,用 FP32x2 软件模拟双精度)。
空间换时间 (用查表法 LUT 替代昂贵的数学指令)。
跨平台适配的克制与灵活 (在国产平台上,尽量使用保守的 C++11 标准;熟练掌握 API 替换规则;并根据目标硬件的 Warp Size 和 SMEM 上限动态调整 Kernel 参数)。