Operator Examples

Ikko

Header Files and Macro Definitions

#include <algorithm>
#include <cmath>
#include <vector>
#include <cuda_fp16.h>

#include "../tester/utils.h"

#define BLOCK_SIZE 256

1. Warp-Level Reduction

Warp Reduce Sum

Fast summation across the 32 threads of a warp (no shared memory needed).

template <typename T>
__device__ __forceinline__ T warpReduceSum(T val) {
  // 0xffffffff means all 32 threads of the warp participate
  // Fold in half each step: 16 -> 8 -> 4 -> 2 -> 1
  // (dividing by 2 is the same as shifting right by 1 bit)
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    // add "current value" + "value from the lane `offset` positions away"
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}
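
As a quick illustration of how warpReduceSum is typically called, the demo kernel below is a sketch and not part of the original code: each thread contributes one value, and after the call lane 0 of every warp holds that warp's sum (it assumes blockDim.x is a multiple of warpSize).

// Illustrative only: each warp sums 32 consecutive elements of `in`;
// lane 0 of the warp writes the warp's total to `out[warp_id]`.
__global__ void warp_sum_demo(const float* in, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // keep all 32 lanes active (out-of-range lanes contribute 0) so the full mask is valid
  float v = (idx < n) ? in[idx] : 0.0f;
  v = warpReduceSum(v);
  if ((threadIdx.x % warpSize) == 0) {
    out[idx / warpSize] = v;  // global warp index
  }
}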

Warp Reduce Max

Helper function: maximum within a warp

template <typename T>
__device__ __forceinline__ T warpReduceMax(T val) {
  // Same folding pattern as warpReduceSum, but keep the max instead of the sum
  for (int offset = 16; offset > 0; offset >>= 1) {
    T temp = __shfl_down_sync(0xffffffff, val, offset);
    if (temp > val) val = temp;
  }
  return val;
}

2. Block-Level Reduction

template <typename T>
__device__ __forceinline__ T blockReduceSum(T val) {
  if (blockDim.x <= warpSize) {
    return warpReduceSum(val);
  }
  // Statically allocated shared memory holding each warp's partial sum.
  // A block has at most 32 warps (1024 threads / 32).
  // Note: when blockDim.x > warpSize this assumes blockDim.x is a multiple of warpSize.
  static __shared__ T shared[32];

  int lane = threadIdx.x % warpSize;  // lane index within the warp (0-31)
  int wid = threadIdx.x / warpSize;   // which warp this thread belongs to

  // First reduce within each warp
  val = warpReduceSum(val);

  // Lane 0 of every warp writes its partial sum to shared memory
  if (lane == 0) {
    shared[wid] = val;
  }

  // Wait until all warps have written their partial sums
  __syncthreads();

  // The first warp (warp 0) adds up the per-warp partial sums.
  // Only the first (blockDim.x / 32) lanes hold valid data.
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;

  // Warp 0 performs one more warp-level reduction
  if (wid == 0) {
    val = warpReduceSum(val);
  }

  return val;
}

3. Computing the Matrix Trace

Trace kernel

template <typename T>
__global__ void traceKernel(const T* __restrict__ input, size_t rows, size_t cols, size_t n, T* out) {
  // stride is the number of elements the whole grid can process per pass
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  size_t stride = blockDim.x * gridDim.x;  // total number of threads in the grid

  // Per-thread local sum kept in a register
  T local_sum = 0;

  // Grid-stride loop, in case one pass of the grid cannot cover all n diagonal elements
  for (size_t i = idx; i < n; i += stride) {
    // i * cols + i is the linear index of the i-th diagonal element
    local_sum += input[i * cols + i];
  }

  // Reduce within the block
  local_sum = blockReduceSum(local_sum);

  // Only thread 0 of each block writes back to global memory,
  // which greatly reduces contention on atomicAdd
  if (threadIdx.x == 0) {
    atomicAdd(out, local_sum);
  }
}

Host-side trace function

template <typename T>
T trace(const std::vector<T>& h_input, size_t rows, size_t cols) {
  if (rows == 0 || cols == 0) {
    return T(0);
  }
  size_t n = std::min(rows, cols);
  size_t total_elems = rows * cols;

  T* d_input = nullptr;
  T* d_out = nullptr;
  RUNTIME_CHECK(cudaMalloc(&d_input, total_elems * sizeof(T)));
  RUNTIME_CHECK(cudaMalloc(&d_out, sizeof(T)));
  RUNTIME_CHECK(cudaMemcpy(d_input, h_input.data(), total_elems * sizeof(T), cudaMemcpyHostToDevice));
  RUNTIME_CHECK(cudaMemset(d_out, 0, sizeof(T)));

  dim3 block(BLOCK_SIZE);
  dim3 grid((n + BLOCK_SIZE - 1) / BLOCK_SIZE);
  traceKernel<T><<<grid, block>>>(d_input, rows, cols, n, d_out);
  RUNTIME_CHECK(cudaGetLastError());
  RUNTIME_CHECK(cudaDeviceSynchronize());

  T h_out = T(0);
  RUNTIME_CHECK(cudaMemcpy(&h_out, d_out, sizeof(T), cudaMemcpyDeviceToHost));

  RUNTIME_CHECK(cudaFree(d_input));
  RUNTIME_CHECK(cudaFree(d_out));
  return h_out;
}
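
A minimal host-side sanity check, assuming it is appended to the same .cu file; the trace_demo function below is illustrative and not part of the original code.

#include <cstdio>  // for printf (used only by this illustrative check)

// Hypothetical sanity check: the trace of the 3x3 matrix [1..9] is 1 + 5 + 9 = 15
void trace_demo() {
  std::vector<float> m = {1.0f, 2.0f, 3.0f,
                          4.0f, 5.0f, 6.0f,
                          7.0f, 8.0f, 9.0f};
  float t = trace<float>(m, 3, 3);
  printf("trace = %f\n", t);  // expected: 15.000000
}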

4. Attention: Type Conversion Helpers

Conversion to float

template <typename T>
__device__ __forceinline__ float to_float(T v);

template <>
__device__ __forceinline__ float to_float<float>(float v) {
  return v;
}

template <>
__device__ __forceinline__ float to_float<half>(half v) {
  return __half2float(v);
}

Conversion back from float

template <typename T>
__device__ __forceinline__ T from_float(float v);

template <>
__device__ __forceinline__ float from_float<float>(float v) {
  return v;
}

template <>
__device__ __forceinline__ half from_float<half>(float v) {
  return __float2half(v);
}

5. Naive Attention Implementation

// Naive attention: each thread computes a single output element o[b, t, qh, d]
template <typename T>
__global__ void native_attention_kernel(const T* q, const T* k, const T* v, T* o,
                                        int batch_size, int target_seq_len, int src_seq_len,
                                        int query_heads, int kv_heads, int head_dim, bool is_causal) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t o_elems = batch_size * target_seq_len * query_heads * head_dim;
  if (idx >= o_elems) {
    return;
  }

  // Linear index -> (b, t, qh, d)
  int d = static_cast<int>(idx % head_dim);
  size_t tmp = idx / head_dim;
  int qh = static_cast<int>(tmp % query_heads);
  tmp /= query_heads;
  int t = static_cast<int>(tmp % target_seq_len);
  int b = static_cast<int>(tmp / target_seq_len);

  // GQA: map query head -> kv head
  int kv_h = (qh * kv_heads) / query_heads;

  const float scale = 1.0f / sqrtf(static_cast<float>(head_dim));

  // Step 1: find the maximum score (for numerical stability)
  float max_score = -INFINITY;
  for (int sk = 0; sk < src_seq_len; ++sk) {
    if (is_causal && sk > t) {
      continue;
    }
    const T* q_ptr = q + (((b * target_seq_len + t) * query_heads + qh) * head_dim);
    const T* k_ptr = k + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim);
    float dot = 0.0f;
    for (int i = 0; i < head_dim; ++i) {
      dot += to_float(q_ptr[i]) * to_float(k_ptr[i]);
    }
    float score = dot * scale;
    if (score > max_score) {
      max_score = score;
    }
  }

  // Step 2: compute the softmax denominator
  float denom = 0.0f;
  for (int sk = 0; sk < src_seq_len; ++sk) {
    if (is_causal && sk > t) {
      continue;
    }
    const T* q_ptr = q + (((b * target_seq_len + t) * query_heads + qh) * head_dim);
    const T* k_ptr = k + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim);
    float dot = 0.0f;
    for (int i = 0; i < head_dim; ++i) {
      dot += to_float(q_ptr[i]) * to_float(k_ptr[i]);
    }
    float score = dot * scale;
    denom += expf(score - max_score);
  }

  if (denom == 0.0f) {
    o[idx] = from_float<T>(0.0f);
    return;
  }

  // Step 3: compute the weighted sum (one output element)
  float out_val = 0.0f;
  for (int sk = 0; sk < src_seq_len; ++sk) {
    if (is_causal && sk > t) {
      continue;
    }
    const T* q_ptr = q + (((b * target_seq_len + t) * query_heads + qh) * head_dim);
    const T* k_ptr = k + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim);
    const T* v_ptr = v + (((b * src_seq_len + sk) * kv_heads + kv_h) * head_dim);
    float dot = 0.0f;
    for (int i = 0; i < head_dim; ++i) {
      dot += to_float(q_ptr[i]) * to_float(k_ptr[i]);
    }
    float score = dot * scale;
    float w = expf(score - max_score) / denom;
    out_val += w * to_float(v_ptr[d]);
  }

  o[idx] = from_float<T>(out_val);
}
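
A host launcher is provided only for Flash Attention further below; for completeness, here is a minimal sketch of how this naive kernel could be launched with one thread per output element. The naiveAttention wrapper and its grid configuration are assumptions and not part of the original code; it expects device buffers that are already allocated and filled.

// Illustrative launch sketch, assuming d_q/d_k/d_v/d_o already live on the device
template <typename T>
void naiveAttention(const T* d_q, const T* d_k, const T* d_v, T* d_o,
                    int batch_size, int target_seq_len, int src_seq_len,
                    int query_heads, int kv_heads, int head_dim, bool is_causal) {
  const size_t o_elems =
      static_cast<size_t>(batch_size) * target_seq_len * query_heads * head_dim;
  dim3 block(BLOCK_SIZE);                                // one thread per output element
  dim3 grid((o_elems + BLOCK_SIZE - 1) / BLOCK_SIZE);
  native_attention_kernel<T><<<grid, block>>>(
      d_q, d_k, d_v, d_o,
      batch_size, target_seq_len, src_seq_len,
      query_heads, kv_heads, head_dim, is_causal);
  RUNTIME_CHECK(cudaGetLastError());
  RUNTIME_CHECK(cudaDeviceSynchronize());
}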

6. Flash Attention V1

Kernel

template <typename T>
__global__ void flash_attention_v1_kernel(const T* Q, const T* K, const T* V, T* O,
                                          int batch_size, int target_seq_len, int src_seq_len,
                                          int query_heads, int kv_heads, int head_dim,
                                          bool is_causal, float scale) {
  // Each block handles one query vector (1 x head_dim)
  const int q_idx = blockIdx.x;
  const int tid = threadIdx.x;  // each thread handles one component of the d dimension

  const int q_vecs = batch_size * target_seq_len * query_heads;
  if (q_idx >= q_vecs || tid >= head_dim) {
    return;
  }

  // Linear index -> (b, t, qh)
  int tmp = q_idx;
  int qh = tmp % query_heads;
  tmp /= query_heads;
  int t = tmp % target_seq_len;
  int b = tmp / target_seq_len;

  // GQA: map query head -> kv head
  int kv_h = (qh * kv_heads) / query_heads;

  const T* q_ptr = Q + q_idx * head_dim;

  // Load this block's Q vector into registers (one component per thread)
  float q_val = to_float(q_ptr[tid]);

  // Online softmax running statistics
  float m_i = -INFINITY;  // running max
  float l_i = 0.0f;       // running denominator
  float o_i = 0.0f;       // running (unnormalized) output component

  // Shared memory holds K, V and the dot-product result
  extern __shared__ float s_mem[];
  float* s_k = s_mem;                      // head_dim floats
  float* s_v = s_mem + head_dim;           // head_dim floats
  float* s_scalar = s_mem + 2 * head_dim;  // 1 float

  for (int j = 0; j < src_seq_len; ++j) {
    if (is_causal && j > t) {
      continue;
    }

    const T* k_ptr = K + (((b * src_seq_len + j) * kv_heads + kv_h) * head_dim);
    const T* v_ptr = V + (((b * src_seq_len + j) * kv_heads + kv_h) * head_dim);

    // Load K and V into shared memory
    s_k[tid] = to_float(k_ptr[tid]);
    s_v[tid] = to_float(v_ptr[tid]);
    __syncthreads();

    // Dot product S = Q * K^T (blockReduceSum combines the per-thread partial products)
    float score = q_val * s_k[tid];
    float dot = blockReduceSum(score);
    if (tid == 0) {
      s_scalar[0] = dot * scale;
    }
    __syncthreads();
    dot = s_scalar[0];

    // Online softmax update
    float m_prev = m_i;
    float l_prev = l_i;

    m_i = fmaxf(m_prev, dot);
    float p_prev = expf(m_prev - m_i);  // rescaling factor for the previous statistics
    float p_curr = expf(dot - m_i);     // weight of the current key/value
    l_i = l_prev * p_prev + p_curr;
    o_i = o_i * p_prev + p_curr * s_v[tid];

    __syncthreads();
  }

  if (l_i > 0.0f) {
    o_i /= l_i;
  } else {
    o_i = 0.0f;
  }

  O[q_idx * head_dim + tid] = from_float<T>(o_i);
}
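
In formula form, the loop above maintains the standard online-softmax recurrence. Writing s_j for the scaled score of key j and v_j for one component of its value vector, each iteration updates

\begin{aligned}
m_j &= \max(m_{j-1},\, s_j) \\
l_j &= l_{j-1}\, e^{m_{j-1}-m_j} + e^{s_j-m_j} \\
o_j &= o_{j-1}\, e^{m_{j-1}-m_j} + e^{s_j-m_j}\, v_j
\end{aligned}

so that after the last visited key, o / l equals exactly the softmax-weighted sum computed by the naive kernel, while only a single K/V row ever resides in shared memory.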

Host-side Flash Attention function

template <typename T>
void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
                    const std::vector<T>& h_v, std::vector<T>& h_o,
                    int batch_size, int target_seq_len, int src_seq_len,
                    int query_heads, int kv_heads, int head_dim, bool is_causal) {
  if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 ||
      query_heads <= 0 || kv_heads <= 0 || head_dim <= 0) {
    h_o.clear();
    return;
  }

  const size_t q_elems = batch_size * target_seq_len * query_heads * head_dim;
  const size_t k_elems = batch_size * src_seq_len * kv_heads * head_dim;
  const size_t v_elems = k_elems;
  const size_t o_elems = q_elems;

  h_o.resize(o_elems);

  T* d_q = nullptr;
  T* d_k = nullptr;
  T* d_v = nullptr;
  T* d_o = nullptr;
  RUNTIME_CHECK(cudaMalloc(&d_q, q_elems * sizeof(T)));
  RUNTIME_CHECK(cudaMalloc(&d_k, k_elems * sizeof(T)));
  RUNTIME_CHECK(cudaMalloc(&d_v, v_elems * sizeof(T)));
  RUNTIME_CHECK(cudaMalloc(&d_o, o_elems * sizeof(T)));

  // Host -> Device: copy Q/K/V to the GPU and zero the output
  RUNTIME_CHECK(cudaMemcpy(d_q, h_q.data(), q_elems * sizeof(T), cudaMemcpyHostToDevice));
  RUNTIME_CHECK(cudaMemcpy(d_k, h_k.data(), k_elems * sizeof(T), cudaMemcpyHostToDevice));
  RUNTIME_CHECK(cudaMemcpy(d_v, h_v.data(), v_elems * sizeof(T), cudaMemcpyHostToDevice));
  RUNTIME_CHECK(cudaMemset(d_o, 0, o_elems * sizeof(T)));

  // Kernel launch: flash attention v1 (each block handles one query vector)
  dim3 block(head_dim);  // one block covers all head_dim components
  dim3 grid(batch_size * target_seq_len * query_heads);
  size_t shared_bytes = (2 * head_dim + 1) * sizeof(float);
  float scale = 1.0f / sqrtf(static_cast<float>(head_dim));
  flash_attention_v1_kernel<T><<<grid, block, shared_bytes>>>(
      d_q, d_k, d_v, d_o,
      batch_size, target_seq_len, src_seq_len,
      query_heads, kv_heads, head_dim, is_causal, scale);

  RUNTIME_CHECK(cudaGetLastError());
  RUNTIME_CHECK(cudaDeviceSynchronize());

  // Device -> Host: copy the result back
  RUNTIME_CHECK(cudaMemcpy(h_o.data(), d_o, o_elems * sizeof(T), cudaMemcpyDeviceToHost));

  RUNTIME_CHECK(cudaFree(d_q));
  RUNTIME_CHECK(cudaFree(d_k));
  RUNTIME_CHECK(cudaFree(d_v));
  RUNTIME_CHECK(cudaFree(d_o));
}
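
A minimal usage sketch: the flash_attention_demo function and its shapes below are illustrative, not part of the original code. With a constant V, every output component equals that constant, which makes the result easy to eyeball.

// Hypothetical usage: batch 1, 4 query positions, 4 key positions,
// 2 query heads sharing 1 kv head (GQA), head_dim = 64, causal mask on.
void flash_attention_demo() {
  const int B = 1, Tq = 4, Tk = 4, QH = 2, KVH = 1, D = 64;
  std::vector<float> q(static_cast<size_t>(B) * Tq * QH * D, 0.01f);
  std::vector<float> k(static_cast<size_t>(B) * Tk * KVH * D, 0.02f);
  std::vector<float> v(static_cast<size_t>(B) * Tk * KVH * D, 0.03f);
  std::vector<float> o;

  flashAttention<float>(q, k, v, o, B, Tq, Tk, QH, KVH, D, /*is_causal=*/true);
  // o now holds B * Tq * QH * D values; since V is constant,
  // every output component should come back as 0.03f.
}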

7. Explicit Template Instantiation

// REQUIRED FOR LINKING WITH TESTER.O
// DO NOT MODIFY THIS SECTION
template int trace<int>(const std::vector<int>&, size_t, size_t);
template float trace<float>(const std::vector<float>&, size_t, size_t);
template void flashAttention<float>(const std::vector<float>&, const std::vector<float>&,
const std::vector<float>&, std::vector<float>&,
int, int, int, int, int, int, bool);
template void flashAttention<half>(const std::vector<half>&, const std::vector<half>&,
const std::vector<half>&, std::vector<half>&,
int, int, int, int, int, int, bool);