Commit f5fe229

refactor(arch): improve codebase architecture depth
Centralize configuration and unify algorithms:

1. Tiling Configuration:
   - Add ForwardTilingConfig and BackwardTilingConfig structs
   - Replace hardcoded block sizes with centralized config
   - Document why forward/backward use different tiling
2. Online Softmax:
   - Add OnlineSoftmaxState::update_with_rescale() method
   - Refactor forward kernel to use unified algorithm
   - Ensure numerical consistency across forward/backward
3. Validation:
   - Add is_supported_head_dim() for centralized checking
   - Update API validation to use single source of truth
4. Documentation:
   - Update CONTEXT.md with new architecture details
   - Add test file comments explaining test vs production paths

1 parent 8d3ad79 commit f5fe229

8 files changed

Lines changed: 133 additions & 34 deletions

CONTEXT.md

Lines changed: 15 additions & 2 deletions
@@ -17,6 +17,8 @@ The unit of computation in FlashAttention. A block of Q, K, or V rows, with size determined by
 
 Supports incremental updates without storing the full softmax matrix.
 
+Implemented uniformly through the `impl::OnlineSoftmaxState` struct, so that the forward and backward passes use the same numerical algorithm.
+
 ### Matmul Operations
 Tile-level matrix multiplication primitives:
 - `matmul_ABt`: C = A @ Bᵀ (attention score computation)
@@ -36,22 +38,32 @@ CUDA device function, executed in parallel on the GPU. FlashAttention kernels are
 ┌─────────────────────────────────────────────────┐
 │ cuflash::flash_attention_forward/backward       │ ← High-level API
 ├─────────────────────────────────────────────────┤
-│ cuflash::kernels::*                             │ ← Kernel primitives (public)
+│ cuflash::kernels::*                             │ ← Kernel primitives (public, for testing)
 │   - online_softmax_init/update/finalize         │
 │   - matmul_ABt, matmul_AB, matmul_AB_acc, AtB   │
 │   - load_tile, store_tile                       │
 ├─────────────────────────────────────────────────┤
 │ src/kernels/impl/*                              │ ← Implementation details (internal)
 │   - OnlineSoftmaxState (device struct)          │
+│   - ForwardTilingConfig, BackwardTilingConfig   │
 │   - __device__ matmul functions                 │
 │   - __device__ tile I/O functions               │
 └─────────────────────────────────────────────────┘
 ```
 
+## Tiling Configuration
+
+Tiling parameters are defined centrally in `src/kernels/impl/tile_io.cuh`:
+
+- **ForwardTilingConfig**: the forward pass uses larger blocks (64x64 by default; 32x32 when head_dim=128)
+- **BackwardTilingConfig**: the backward pass uses smaller blocks (64x64 by default; 16x32 when head_dim=128)
+
+Reason for the difference: the backward pass must keep more gradient tensors (dQ, dK, dV) in shared memory, so it needs a more conservative tiling strategy.
+
 ## Key Invariants
 
 1. **Tensor Layout**: `[batch_size, num_heads, seq_len, head_dim]` (immutable)
-2. **Supported head_dim**: 32, 64, 128 (determined by kernel template instantiation)
+2. **Supported head_dim**: 32, 64, 128 (checked centrally by `impl::is_supported_head_dim()`)
 3. **Data Types**: FP32 (float) and FP16 (half); internal computation always uses float
 4. **Stream Safety**: all CUDA operations take an explicit stream parameter
 
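@@ -60,3 +72,4 @@ CUDA device function, executed in parallel on the GPU. FlashAttention kernels are
 1. **Depth over Shallow**: kernel utilities expose a public interface; tests do not reach into implementation details
 2. **Primitive Decomposition**: complex operations are decomposed into composable primitives
 3. **Template for Performance**: M, N, K are compile-time template parameters, which lets the compiler optimize the kernels
+4. **Centralized Configuration**: tiling configuration and supported head_dim values are defined in one place rather than scattered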
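
The tiling rule documented in CONTEXT.md (default blocks, with a smaller pair when head_dim=128) can be sketched on the host as follows. The block sizes are copied from this commit; `TileShape`, `select_tile`, and `num_q_blocks` are hypothetical names introduced only for illustration, and the ceiling division mirrors the `num_q_blocks` expression in the forward launcher.

```cpp
// Host-side sketch. Block sizes come from the diff; helper names are hypothetical.
struct ForwardTilingConfig {
    static constexpr int BLOCK_M = 64, BLOCK_N = 64;
    static constexpr int BLOCK_M_HD128 = 32, BLOCK_N_HD128 = 32;
};

struct BackwardTilingConfig {
    static constexpr int BLOCK_M = 64, BLOCK_N = 64;
    static constexpr int BLOCK_M_HD128 = 16, BLOCK_N_HD128 = 32;
};

struct TileShape {
    int block_m;  // Q block rows
    int block_n;  // K/V block rows
};

// Pick the tile shape for a head_dim: the HD128 pair for 128, the default otherwise.
template <typename Config>
constexpr TileShape select_tile(int head_dim) {
    return head_dim == 128 ? TileShape{Config::BLOCK_M_HD128, Config::BLOCK_N_HD128}
                           : TileShape{Config::BLOCK_M, Config::BLOCK_N};
}

// Number of Q blocks needed to cover seq_len rows (ceiling division).
constexpr int num_q_blocks(int seq_len, int block_m) {
    return (seq_len + block_m - 1) / block_m;
}

// Forward and backward only diverge at head_dim = 128.
static_assert(select_tile<ForwardTilingConfig>(128).block_m == 32, "forward HD128");
static_assert(select_tile<BackwardTilingConfig>(128).block_m == 16, "backward HD128");
static_assert(select_tile<BackwardTilingConfig>(64).block_m == 64, "backward default");
```

Because both structs expose the same member names, a launcher can be written once against a `Config` template parameter; whether the real launchers do this beyond the `using Config = ...` alias is not shown in the diff.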

src/api/flash_attention_api.cu

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
 #include "cuflash/flash_attention.h"
+#include "impl/tile_io.cuh"  // For is_supported_head_dim
 
 namespace cuflash {
 
@@ -55,8 +56,8 @@ static FlashAttentionError validate_common_params(const void* Q, const void* K,
         return FlashAttentionError::INVALID_DIMENSION;
     }
 
-    // Check supported head_dim values
-    if (head_dim != 32 && head_dim != 64 && head_dim != 128) {
+    // Check supported head_dim values (centralized check)
+    if (!impl::is_supported_head_dim(head_dim)) {
         return FlashAttentionError::UNSUPPORTED_HEAD_DIM;
     }
 

src/backward/flash_attention_backward_typed.cu

Lines changed: 11 additions & 8 deletions
@@ -4,6 +4,7 @@
 #include <float.h>
 
 #include "cuflash/flash_attention.h"
+#include "impl/online_softmax.cuh"
 #include "impl/tile_io.cuh"
 #include "kernel_launch_utils.cuh"
 #include "workspace_utils.cuh"
@@ -371,10 +372,11 @@ FlashAttentionError launch_flash_attention_backward_typed<float>(
     const float* Q, const float* K, const float* V, const float* O, const float* L, const float* dO,
     float* dQ, float* dK, float* dV, int batch_size, int num_heads, int seq_len, int head_dim,
     float scale, bool causal, cudaStream_t stream) {
-    constexpr int BLOCK_M = 64;
-    constexpr int BLOCK_N = 64;
-    constexpr int BLOCK_M_HD128 = 16;
-    constexpr int BLOCK_N_HD128 = 32;
+    using Config = impl::BackwardTilingConfig;
+    constexpr int BLOCK_M = Config::BLOCK_M;
+    constexpr int BLOCK_N = Config::BLOCK_N;
+    constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128;
+    constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128;
 
     int batch_heads = batch_size * num_heads;
 
@@ -526,10 +528,11 @@ FlashAttentionError launch_flash_attention_backward_typed<half>(
     const half* Q, const half* K, const half* V, const half* O, const half* L, const half* dO,
     half* dQ, half* dK, half* dV, int batch_size, int num_heads, int seq_len, int head_dim,
     float scale, bool causal, cudaStream_t stream) {
-    constexpr int BLOCK_M = 64;
-    constexpr int BLOCK_N = 64;
-    constexpr int BLOCK_M_HD128 = 16;
-    constexpr int BLOCK_N_HD128 = 32;
+    using Config = impl::BackwardTilingConfig;
+    constexpr int BLOCK_M = Config::BLOCK_M;
+    constexpr int BLOCK_N = Config::BLOCK_N;
+    constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128;
+    constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128;
 
     int batch_heads = batch_size * num_heads;
 

src/forward/flash_attention_forward_typed.cu

Lines changed: 25 additions & 19 deletions
@@ -4,6 +4,7 @@
 #include <float.h>
 
 #include "cuflash/flash_attention.h"
+#include "impl/online_softmax.cuh"
 #include "impl/tile_io.cuh"
 #include "kernel_launch_utils.cuh"
 
@@ -105,13 +106,15 @@ __global__ void __launch_bounds__(128)
             if (q_start + row >= seq_len)
                 continue;
 
+            // Compute row max for this KV block
             float row_max = -INFINITY;
             for (int j = 0; j < BLOCK_N; j++) {
                 if (kv_start + j < seq_len) {
                     row_max = fmaxf(row_max, S_tile[row * BLOCK_N + j]);
                 }
             }
 
+            // Compute row sum (exp) and convert scores to probabilities
             float row_sum = 0.0f;
             for (int j = 0; j < BLOCK_N; j++) {
                 if (kv_start + j < seq_len) {
@@ -122,30 +125,31 @@ __global__ void __launch_bounds__(128)
                 }
             }
 
-            // Update online softmax state
-            float m_old = m_tile[row];
-            float l_old = l_tile[row];
-            float m_new = fmaxf(m_old, row_max);
-            float l_new = l_old * expf(m_old - m_new) + row_sum * expf(row_max - m_new);
+            // Update online softmax state using unified algorithm
+            impl::OnlineSoftmaxState state;
+            state.m = m_tile[row];
+            state.l = l_tile[row];
+
+            float rescale_existing, scale_new;
+            state.update_with_rescale(row_max, row_sum, rescale_existing, scale_new);
 
             // Rescale existing O
-            float rescale = expf(m_old - m_new);
             for (int d = 0; d < HEAD_DIM; d++) {
-                O_tile[row * HEAD_DIM + d] *= rescale;
+                O_tile[row * HEAD_DIM + d] *= rescale_existing;
             }
 
             // Add contribution from this block: P @ V
-            float p_scale = expf(row_max - m_new);
             for (int d = 0; d < HEAD_DIM; d++) {
                 float sum = 0.0f;
                 for (int j = 0; j < BLOCK_N; j++) {
                     sum += S_tile[row * BLOCK_N + j] * V_tile[j * HEAD_DIM + d];
                 }
-                O_tile[row * HEAD_DIM + d] += sum * p_scale;
+                O_tile[row * HEAD_DIM + d] += sum * scale_new;
             }
 
-            m_tile[row] = m_new;
-            l_tile[row] = l_new;
+            // Store updated state
+            m_tile[row] = state.m;
+            l_tile[row] = state.l;
         }
         __syncthreads();
     }
@@ -199,10 +203,11 @@ template<>
 FlashAttentionError launch_flash_attention_forward_typed<float>(
     const float* Q, const float* K, const float* V, float* O, float* L, int batch_size,
     int num_heads, int seq_len, int head_dim, float scale, bool causal, cudaStream_t stream) {
-    constexpr int BLOCK_M = 64;
-    constexpr int BLOCK_N = 64;
-    constexpr int BLOCK_M_HD128 = 32;
-    constexpr int BLOCK_N_HD128 = 32;
+    using Config = impl::ForwardTilingConfig;
+    constexpr int BLOCK_M = Config::BLOCK_M;
+    constexpr int BLOCK_N = Config::BLOCK_N;
+    constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128;
+    constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128;
 
     const int batch_heads = batch_size * num_heads;
     const int num_q_blocks = (seq_len + BLOCK_M - 1) / BLOCK_M;
@@ -271,10 +276,11 @@ template<>
 FlashAttentionError launch_flash_attention_forward_typed<half>(
     const half* Q, const half* K, const half* V, half* O, half* L, int batch_size, int num_heads,
     int seq_len, int head_dim, float scale, bool causal, cudaStream_t stream) {
-    constexpr int BLOCK_M = 64;
-    constexpr int BLOCK_N = 64;
-    constexpr int BLOCK_M_HD128 = 32;
-    constexpr int BLOCK_N_HD128 = 32;
+    using Config = impl::ForwardTilingConfig;
+    constexpr int BLOCK_M = Config::BLOCK_M;
+    constexpr int BLOCK_N = Config::BLOCK_N;
+    constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128;
+    constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128;
 
     const int batch_heads = batch_size * num_heads;
     const int num_q_blocks = (seq_len + BLOCK_M - 1) / BLOCK_M;

src/kernels/impl/online_softmax.cuh

Lines changed: 19 additions & 0 deletions
@@ -37,6 +37,25 @@ struct OnlineSoftmaxState {
 
     /// Get the normalization factor for final output
     __device__ __forceinline__ float get_normalizer() const { return 1.0f / l; }
+
+    /// Update from local block statistics and return the rescale factor for existing O.
+    /// This combines update() with computing the necessary scaling for output accumulation.
+    /// @param new_m max value in the new block
+    /// @param new_l sum of exp(x - new_m) in the new block
+    /// @param rescale_existing output: factor to multiply existing O by
+    /// @param scale_new output: factor to multiply new P@V by
+    __device__ __forceinline__ void update_with_rescale(float new_m, float new_l,
+                                                        float& rescale_existing, float& scale_new) {
+        float m_old = m;
+        float l_old = l;
+        float m_new = fmaxf(m_old, new_m);
+        l = l_old * expf(m_old - m_new) + new_l * expf(new_m - m_new);
+        m = m_new;
+
+        // Compute rescaling factors
+        rescale_existing = expf(m_old - m_new);  // For existing O
+        scale_new = expf(new_m - m_new);         // For new P @ V contribution
+    }
 };
 
 // =============================================================================
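
As a rough check of what `update_with_rescale()` is meant to guarantee, the host-side sketch below mirrors the struct without the `__device__` qualifiers and walks one attention row in blocks, the way the forward kernel walks KV blocks. It assumes the state starts at m = -inf, l = 0 (the initialization is not shown in the diff), uses scalar values (head_dim = 1) for brevity, and the helpers `attend_blockwise` and `attend_direct` are hypothetical names, not part of cuflash.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side mirror of the struct from the diff.
// Assumption: the state starts at m = -inf, l = 0.
struct OnlineSoftmaxState {
    float m = -INFINITY;  // running max
    float l = 0.0f;       // running sum of exp(x - m)

    void update_with_rescale(float new_m, float new_l,
                             float& rescale_existing, float& scale_new) {
        const float m_old = m;
        const float m_new = std::fmax(m_old, new_m);
        l = l * std::exp(m_old - m_new) + new_l * std::exp(new_m - m_new);
        m = m_new;
        rescale_existing = std::exp(m_old - m_new);  // for the existing accumulator
        scale_new = std::exp(new_m - m_new);         // for the new block's P @ V
    }
};

// Softmax(scores)-weighted average of scalar values, `block` entries at a time.
// `block` must be greater than zero.
inline float attend_blockwise(const std::vector<float>& scores,
                              const std::vector<float>& values, std::size_t block) {
    OnlineSoftmaxState state;
    float o = 0.0f;  // unnormalized output accumulator for one row
    for (std::size_t start = 0; start < scores.size(); start += block) {
        const std::size_t end = std::min(start + block, scores.size());
        float row_max = -INFINITY;
        for (std::size_t j = start; j < end; ++j) row_max = std::fmax(row_max, scores[j]);
        float row_sum = 0.0f, pv = 0.0f;
        for (std::size_t j = start; j < end; ++j) {
            const float p = std::exp(scores[j] - row_max);  // block-local probability
            row_sum += p;
            pv += p * values[j];
        }
        float rescale_existing, scale_new;
        state.update_with_rescale(row_max, row_sum, rescale_existing, scale_new);
        o = o * rescale_existing + pv * scale_new;
    }
    return o / state.l;  // same as multiplying by get_normalizer()
}

// Reference: the same quantity with an ordinary two-pass softmax.
inline float attend_direct(const std::vector<float>& scores,
                           const std::vector<float>& values) {
    const float mx = *std::max_element(scores.begin(), scores.end());
    float denom = 0.0f, num = 0.0f;
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const float e = std::exp(scores[j] - mx);
        denom += e;
        num += e * values[j];
    }
    return num / denom;
}
```

The two functions should agree up to floating-point rounding for any block size. One caveat worth keeping in mind: if both the running max and the block max are -inf (for example, a fully masked block arriving first), `m_old - m_new` is NaN under IEEE arithmetic, so callers would need to guard that case separately.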

src/kernels/impl/tile_io.cuh

Lines changed: 43 additions & 3 deletions
@@ -422,14 +422,54 @@ __device__ __forceinline__ void matmul_AtB(const float* __restrict__ A, // KxM
     }
 }
 
-// Tiling configuration
-struct TilingConfig {
+// =============================================================================
+// Tiling Configuration
+// =============================================================================
+// Centralized tiling configuration for FlashAttention kernels.
+// Different configurations for forward and backward passes due to memory
+// constraints in backward pass requiring smaller blocks.
+
+/// Tiling configuration for forward pass.
+/// Uses larger blocks for better memory throughput.
+struct ForwardTilingConfig {
+    // Standard block sizes for head_dim 32 and 64
+    static constexpr int BLOCK_M = 64;  // Q block rows
+    static constexpr int BLOCK_N = 64;  // K/V block rows
+
+    // Smaller blocks for head_dim 128 (shared memory constraint)
+    static constexpr int BLOCK_M_HD128 = 32;
+    static constexpr int BLOCK_N_HD128 = 32;
+
+    static constexpr int NUM_THREADS = 128;
+    static constexpr int WARP_SIZE = 32;
+};
+
+/// Tiling configuration for backward pass.
+/// Uses smaller blocks to accommodate additional gradient tensors in shared memory.
+struct BackwardTilingConfig {
+    // Standard block sizes for head_dim 32 and 64
     static constexpr int BLOCK_M = 64;  // Q block rows
     static constexpr int BLOCK_N = 64;  // K/V block rows
-    static constexpr int BLOCK_K = 64;  // Head dimension tile
+
+    // Smaller blocks for head_dim 128 (more aggressive due to dQ, dK, dV)
+    static constexpr int BLOCK_M_HD128 = 16;
+    static constexpr int BLOCK_N_HD128 = 32;
+
     static constexpr int NUM_THREADS = 128;
     static constexpr int WARP_SIZE = 32;
 };
 
+/// Supported head dimensions - single source of truth.
+/// Used for validation and kernel dispatch.
+inline constexpr int SUPPORTED_HEAD_DIMS[] = {32, 64, 128};
+
+/// Check if a head_dim value is supported.
+inline constexpr bool is_supported_head_dim(int head_dim) {
+    return head_dim == 32 || head_dim == 64 || head_dim == 128;
+}
+
+// Legacy alias for backward compatibility
+using TilingConfig = ForwardTilingConfig;
+
 } // namespace impl
 } // namespace cuflash
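
In the hunk above, `SUPPORTED_HEAD_DIMS` and `is_supported_head_dim()` each spell out 32, 64, and 128, so the list exists in two places. One possible way to make the array the only copy is to have the check iterate over it; the sketch below is a hypothetical variant for the host-side validation path, not the code in this commit, and it assumes C++17 as the original `inline constexpr` declaration already does.

```cpp
// Values copied from the diff; the loop-based check is a hypothetical variant.
inline constexpr int SUPPORTED_HEAD_DIMS[] = {32, 64, 128};

// Returns true if head_dim appears in SUPPORTED_HEAD_DIMS.
constexpr bool is_supported_head_dim(int head_dim) {
    for (int d : SUPPORTED_HEAD_DIMS) {
        if (d == head_dim) return true;
    }
    return false;
}

// The check is usable at compile time as well as at runtime.
static_assert(is_supported_head_dim(64), "64 is listed");
static_assert(!is_supported_head_dim(96), "96 is not listed");
```

With this shape, adding a new head dimension means editing the array only. Kernel dispatch would still need a matching template instantiation, which this sketch does not cover.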

tests/unit/test_matmul.cu

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,13 @@
 // Matmul Unit Tests
 // Tests for cuflash::kernels::matmul_* operations
+//
+// NOTE: These tests validate the standalone kernel API for matmul operations.
+// The production forward/backward kernels use impl::matmul_* directly in
+// shared memory for better performance. The kernels::matmul_* API provides
+// a testable interface for the same underlying algorithms.
+//
+// Test coverage: The numerical correctness of impl::matmul_* is indirectly
+// validated through the end-to-end tests in tests/integration/.
 
 #include <gtest/gtest.h>
 #if CUFLASH_ENABLE_RAPIDCHECK

tests/unit/test_tile_io.cu

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,14 @@
 // Tile I/O Unit Tests
 // Tests for cuflash::kernels::load_tile and store_tile operations
+//
+// NOTE: These tests validate the standalone kernel API for tile I/O operations.
+// The production forward/backward kernels use impl::load_tile_to_shared and
+// impl::store_tile_from_shared directly for better performance (avoiding
+// kernel launch overhead for each tile). The kernels::* API provides a
+// testable interface for the same underlying algorithms.
+//
+// Test coverage: The correctness of impl::* tile functions is validated through
+// end-to-end tests in tests/integration/.
 
 #include <gtest/gtest.h>
 #if CUFLASH_ENABLE_RAPIDCHECK
