1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
// flashAttention -- host driver for flash_attention_v1_kernel<T>.
//
// Copies Q, K and V into freshly allocated device buffers, zero-fills the
// device output, launches the kernel once on the default stream, waits for
// it, and copies the result back into h_o (resized to the output element
// count). Fully synchronous: every transfer is a blocking cudaMemcpy and the
// launch is followed by cudaDeviceSynchronize().
//
// Required host element counts (larger vectors are accepted; only the
// leading elements are read):
//   h_q       : batch_size * target_seq_len * query_heads * head_dim
//   h_k, h_v  : batch_size * src_seq_len    * kv_heads    * head_dim
//   h_o (out) : same count as h_q
// NOTE(review): the index order inside these buffers is defined by
// flash_attention_v1_kernel, presumably [batch][seq][heads][head_dim] given
// the factor order above -- confirm against the kernel.
//
// Launch shape: batch_size * target_seq_len * query_heads blocks of head_dim
// threads each, with (2 * head_dim + 1) floats of dynamic shared memory. The
// kernel also receives is_causal and scale = 1 / sqrt(head_dim).
//
// Error behaviour:
//   * any dimension <= 0            -> h_o cleared, no GPU work is done.
//   * h_q / h_k / h_v too short     -> h_o cleared, cudaErrorInvalidValue is
//                                      reported through RUNTIME_CHECK.
//   * block count > 2^31 - 1        -> h_o cleared, cudaErrorInvalidConfiguration
//                                      is reported through RUNTIME_CHECK.
//   * every CUDA API call and the kernel launch are checked with RUNTIME_CHECK
//     (e.g. head_dim above the device's threads-per-block limit surfaces as a
//     launch error from cudaGetLastError()).
template <typename T>
void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
                    const std::vector<T>& h_v, std::vector<T>& h_o,
                    int batch_size, int target_seq_len, int src_seq_len,
                    int query_heads, int kv_heads, int head_dim, bool is_causal) {
    if (batch_size <= 0 || target_seq_len <= 0 || src_seq_len <= 0 ||
        query_heads <= 0 || kv_heads <= 0 || head_dim <= 0) {
        h_o.clear();
        return;
    }

    // Widen every factor to size_t *before* multiplying so the element counts
    // and the block count cannot overflow int for large shapes.
    const size_t batch = static_cast<size_t>(batch_size);
    const size_t dim = static_cast<size_t>(head_dim);
    const size_t num_rows = batch * static_cast<size_t>(target_seq_len) *
                            static_cast<size_t>(query_heads);
    const size_t q_elems = num_rows * dim;
    const size_t k_elems = batch * static_cast<size_t>(src_seq_len) *
                           static_cast<size_t>(kv_heads) * dim;
    const size_t v_elems = k_elems;
    const size_t o_elems = q_elems;

    // The host-to-device copies below read q/k/v_elems elements from each
    // vector; an undersized vector would be an out-of-bounds host read.
    if (h_q.size() < q_elems || h_k.size() < k_elems || h_v.size() < v_elems) {
        h_o.clear();
        RUNTIME_CHECK(cudaErrorInvalidValue);
        return;  // reached only if RUNTIME_CHECK neither throws nor terminates
    }

    // gridDim.x is capped at 2^31 - 1; anything larger would also be silently
    // truncated by the conversion to dim3's unsigned int.
    if (num_rows > 0x7FFFFFFFu) {
        h_o.clear();
        RUNTIME_CHECK(cudaErrorInvalidConfiguration);
        return;  // reached only if RUNTIME_CHECK neither throws nor terminates
    }

    // Owns one device allocation and frees it on scope exit, so the buffers
    // are not leaked if a RUNTIME_CHECK below exits by throwing. (Whether
    // RUNTIME_CHECK throws or terminates depends on its definition elsewhere.)
    struct DeviceBuffer {
        T* ptr = nullptr;
        DeviceBuffer() = default;
        DeviceBuffer(const DeviceBuffer&) = delete;
        DeviceBuffer& operator=(const DeviceBuffer&) = delete;
        ~DeviceBuffer() {
            if (ptr != nullptr) cudaFree(ptr);  // best effort on error paths
        }
        // Gives up ownership so the normal path can free with error checking
        // without the destructor freeing the same pointer again.
        T* release() {
            T* p = ptr;
            ptr = nullptr;
            return p;
        }
    };

    h_o.resize(o_elems);

    DeviceBuffer d_q, d_k, d_v, d_o;
    RUNTIME_CHECK(cudaMalloc(&d_q.ptr, q_elems * sizeof(T)));
    RUNTIME_CHECK(cudaMalloc(&d_k.ptr, k_elems * sizeof(T)));
    RUNTIME_CHECK(cudaMalloc(&d_v.ptr, v_elems * sizeof(T)));
    RUNTIME_CHECK(cudaMalloc(&d_o.ptr, o_elems * sizeof(T)));

    RUNTIME_CHECK(cudaMemcpy(d_q.ptr, h_q.data(), q_elems * sizeof(T), cudaMemcpyHostToDevice));
    RUNTIME_CHECK(cudaMemcpy(d_k.ptr, h_k.data(), k_elems * sizeof(T), cudaMemcpyHostToDevice));
    RUNTIME_CHECK(cudaMemcpy(d_v.ptr, h_v.data(), v_elems * sizeof(T), cudaMemcpyHostToDevice));
    RUNTIME_CHECK(cudaMemset(d_o.ptr, 0, o_elems * sizeof(T)));

    const dim3 block(static_cast<unsigned int>(head_dim));
    const dim3 grid(static_cast<unsigned int>(num_rows));
    const size_t shared_bytes = (2 * dim + 1) * sizeof(float);
    const float scale = 1.0f / sqrtf(static_cast<float>(head_dim));
    flash_attention_v1_kernel<T><<<grid, block, shared_bytes>>>(
        d_q.ptr, d_k.ptr, d_v.ptr, d_o.ptr, batch_size, target_seq_len,
        src_seq_len, query_heads, kv_heads, head_dim, is_causal, scale);
    RUNTIME_CHECK(cudaGetLastError());       // launch-configuration errors
    RUNTIME_CHECK(cudaDeviceSynchronize());  // errors raised while the kernel ran

    RUNTIME_CHECK(cudaMemcpy(h_o.data(), d_o.ptr, o_elems * sizeof(T), cudaMemcpyDeviceToHost));

    // release() nulls the owner first, so a failing cudaFree cannot lead to a
    // second free from the destructor.
    RUNTIME_CHECK(cudaFree(d_q.release()));
    RUNTIME_CHECK(cudaFree(d_k.release()));
    RUNTIME_CHECK(cudaFree(d_v.release()));
    RUNTIME_CHECK(cudaFree(d_o.release()));
}
|