Commit 34cf3c0

LessUp and Copilot authored
refactor(kernels): deepen primitive standalone layer (#1)
- centralize standalone primitive wrapper validation and launch helpers
- unify FP32/FP16 conversion through TypeAdapter
- add online_softmax and matmul regression coverage

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 7427743 commit 34cf3c0

12 files changed

Lines changed: 581 additions & 177 deletions
Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
# Primitive Layer Deepening Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Consolidate the shallow standalone primitive layer of `online_softmax`, `matmul`, and `tile_io`, unifying host-side validation, kernel launch, and type conversion in a deeper internal module.

**Architecture:** Keep the public headers stable for now and avoid introducing spec/API breakage in the same round. First fold the duplicated host-side validation, dynamic shared-memory launch, and kernel launch error handling into a single internal helper seam, then converge FP32/FP16 conversion on a single type adapter policy. The wrappers stay compatible, while the implementation becomes deeper and gains better locality.

**Tech Stack:** CUDA C++17, CMake presets, GoogleTest

---

## File Map

- Create: `src/kernels/primitive_api_utils.cuh`
- Create: `src/kernels/impl/type_adapter.cuh`
- Modify: `src/kernels/online_softmax.cu`
- Modify: `src/kernels/matmul.cu`
- Modify: `src/kernels/tile_io.cu`
- Modify: `src/kernels/impl/tile_io.cuh`
- Modify: `src/forward/flash_attention_forward_typed.cu`
- Modify: `src/backward/flash_attention_backward_typed.cu`
- Test: `tests/unit/test_online_softmax.cu`
- Test: `tests/unit/test_matmul.cu`
- Test: `tests/unit/test_dtype.cu`

### Task 1: Centralize primitive wrapper validation and launch helpers

**Files:**
- Create: `src/kernels/primitive_api_utils.cuh`
- Modify: `src/kernels/online_softmax.cu`
- Modify: `src/kernels/matmul.cu`
- Modify: `src/kernels/tile_io.cu`

- [ ] **Step 1: Write the failing test**

Use existing characterization tests as the guard surface:

```cpp
EXPECT_EQ(kernels::online_softmax_forward(nullptr, d_valid, d_valid, 4, 4, 2, stream),
          FlashAttentionError::NULL_POINTER);
EXPECT_EQ(kernels::matmul_ABt<64, 64, 32>(nullptr, d_valid, d_valid, 1.0f, stream),
          FlashAttentionError::NULL_POINTER);
EXPECT_EQ(kernels::load_tile<64, 64>(d_src, d_dst, -1, 0, 128, 128, 128, stream),
          FlashAttentionError::INVALID_DIMENSION);
```

- [ ] **Step 2: Run test to verify it fails**

Run:

```bash
ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest|TileIOTest"
```

Expected: the current machine has no `nvcc`, so this stalls at configure/build. In a CUDA-enabled environment, this step demonstrates that behavioral regressions are observable before and after the refactor.

- [ ] **Step 3: Write minimal implementation**

Create one internal helper seam:

```cpp
inline FlashAttentionError validate_non_null(std::initializer_list<const void*> ptrs);
inline FlashAttentionError validate_positive_dimensions(std::initializer_list<int> dims);
inline FlashAttentionError validate_tile_window(int row_start, int col_start,
                                                int max_rows, int max_cols, int stride);
template <typename KernelFunc>
inline FlashAttentionError prepare_kernel_launch(KernelFunc kernel, size_t smem_size);
inline FlashAttentionError finish_kernel_launch();
```

Then replace the file-local helpers with calls like:

```cpp
FlashAttentionError err = detail::validate_non_null({A, B, C});
if (err != FlashAttentionError::SUCCESS) return err;
```
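
As an illustration of what the three validation helpers might look like, here is a minimal host-only sketch. The `FlashAttentionError` enum below is a stand-in that models only the three values named in this plan, `check_matmul_args` and `usage_example` are hypothetical names, and the rule applied by `validate_tile_window` (non-negative offsets, positive extents and stride) is an assumption rather than the project's confirmed behavior. The launch helpers are left out because they depend on the CUDA runtime.

```cpp
#include <cassert>
#include <initializer_list>

// Stand-in for the project's error enum (assumption: only the three
// values that appear in this plan are modeled).
enum class FlashAttentionError { SUCCESS, NULL_POINTER, INVALID_DIMENSION };

namespace detail {

// Returns NULL_POINTER if any pointer in the list is null.
inline FlashAttentionError validate_non_null(std::initializer_list<const void*> ptrs) {
    for (const void* p : ptrs) {
        if (p == nullptr) return FlashAttentionError::NULL_POINTER;
    }
    return FlashAttentionError::SUCCESS;
}

// Returns INVALID_DIMENSION if any dimension is zero or negative.
inline FlashAttentionError validate_positive_dimensions(std::initializer_list<int> dims) {
    for (int d : dims) {
        if (d <= 0) return FlashAttentionError::INVALID_DIMENSION;
    }
    return FlashAttentionError::SUCCESS;
}

// Assumed rule: offsets must be non-negative; extents and stride must be positive.
inline FlashAttentionError validate_tile_window(int row_start, int col_start,
                                                int max_rows, int max_cols, int stride) {
    if (row_start < 0 || col_start < 0) return FlashAttentionError::INVALID_DIMENSION;
    return validate_positive_dimensions({max_rows, max_cols, stride});
}

}  // namespace detail

// Hypothetical wrapper prologue showing the early-return pattern.
inline FlashAttentionError check_matmul_args(const float* A, const float* B, const float* C,
                                             int M, int N, int K) {
    FlashAttentionError err = detail::validate_non_null({A, B, C});
    if (err != FlashAttentionError::SUCCESS) return err;
    return detail::validate_positive_dimensions({M, N, K});
}

// Mirrors the shape of the characterization tests in Step 1.
inline void usage_example() {
    float buf[4] = {};
    assert(check_matmul_args(buf, buf, buf, 64, 64, 32) == FlashAttentionError::SUCCESS);
    assert(check_matmul_args(nullptr, buf, buf, 64, 64, 32) == FlashAttentionError::NULL_POINTER);
    assert(detail::validate_tile_window(-1, 0, 128, 128, 128) ==
           FlashAttentionError::INVALID_DIMENSION);
}
```

Taking `std::initializer_list` keeps each call site to a single line, which is what allows the per-file helpers to collapse into one shared header.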

- [ ] **Step 4: Run test to verify it passes**

Run:

```bash
ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest|TileIOTest"
```

Expected: in a CUDA-enabled environment, the existing null/dimension tests continue to PASS.

- [ ] **Step 5: Commit**

```bash
git add src/kernels/primitive_api_utils.cuh src/kernels/online_softmax.cu src/kernels/matmul.cu src/kernels/tile_io.cu
git commit -m "refactor(kernels): centralize primitive wrapper helpers"
```

### Task 2: Unify FP32/FP16 conversion policy

**Files:**
- Create: `src/kernels/impl/type_adapter.cuh`
- Modify: `src/kernels/impl/tile_io.cuh`
- Modify: `src/forward/flash_attention_forward_typed.cu`
- Modify: `src/backward/flash_attention_backward_typed.cu`
- Test: `tests/unit/test_dtype.cu`

- [ ] **Step 1: Write the failing test**

Use the current FP16 characterization tests:

```cpp
TEST(DTypeTest, FP16ForwardMatchesFP32);
TEST(DTypeTest, FP16BackwardMatchesFP32);
```

Keep the finite-gradient checks as explicit acceptance criteria:

```cpp
EXPECT_TRUE(std::isfinite(__half2float(h_dQ[i])));
EXPECT_TRUE(std::isfinite(__half2float(h_dK[i])));
EXPECT_TRUE(std::isfinite(__half2float(h_dV[i])));
```

- [ ] **Step 2: Run test to verify it fails**

Run:

```bash
ctest --preset release --output-on-failure -R DTypeTest
```

Expected: the current machine is still blocked by the missing `nvcc`. In a CUDA-enabled environment, this step confirms that the refactor has not broken the FP16/FP32 paths.

- [ ] **Step 3: Write minimal implementation**

Create one trait:

```cpp
template <typename T>
struct TypeAdapter;

template <>
struct TypeAdapter<float> {
    __device__ static float to_compute(float value) { return value; }
    __device__ static float from_compute(float value) { return value; }
};

template <>
struct TypeAdapter<half> {
    __device__ static float to_compute(half value) { return __half2float(value); }
    __device__ static half from_compute(float value) { return __float2half(value); }
};
```

Then replace scattered conversions with calls like:

```cpp
sum += impl::TypeAdapter<InputT>::to_compute(dO_row[d]) *
       impl::TypeAdapter<InputT>::to_compute(O_row[d]);
L_ptr[global_row] =
    impl::TypeAdapter<InputT>::from_compute(m_tile[row] + logf(l_tile[row]));
```
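
To show how a single trait lets one templated routine serve several storage types, here is a host-only sketch of the same pattern. Because `half` and `__half2float` require the CUDA headers, the sketch substitutes a hypothetical `Bf16Storage` type that simply keeps the top 16 bits of a `float`. That stand-in, the `row_dot` helper, and `usage_example` are assumptions made for illustration and are not part of the project.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical 16-bit storage type standing in for CUDA's `half`.
// It keeps the upper 16 bits of an IEEE-754 float (bfloat16-style truncation).
struct Bf16Storage {
    std::uint16_t bits;
};

template <typename T>
struct TypeAdapter;

// FP32 storage: conversion to and from the compute type is the identity.
template <>
struct TypeAdapter<float> {
    static float to_compute(float value) { return value; }
    static float from_compute(float value) { return value; }
};

// 16-bit storage: widen to float for arithmetic, truncate when storing.
template <>
struct TypeAdapter<Bf16Storage> {
    static float to_compute(Bf16Storage value) {
        const std::uint32_t u = static_cast<std::uint32_t>(value.bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof f);
        return f;
    }
    static Bf16Storage from_compute(float value) {
        std::uint32_t u;
        std::memcpy(&u, &value, sizeof u);
        return Bf16Storage{static_cast<std::uint16_t>(u >> 16)};
    }
};

// One routine for every storage type: all arithmetic happens in float.
template <typename InputT>
float row_dot(const InputT* a, const InputT* b, std::size_t n) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        sum += TypeAdapter<InputT>::to_compute(a[i]) * TypeAdapter<InputT>::to_compute(b[i]);
    }
    return sum;
}

// Small integers survive the 16-bit round trip exactly, so both paths agree.
inline void usage_example() {
    const float a[3] = {1.0f, 2.0f, 3.0f};
    const float b[3] = {4.0f, 5.0f, 6.0f};
    Bf16Storage ha[3];
    Bf16Storage hb[3];
    for (std::size_t i = 0; i < 3; ++i) {
        ha[i] = TypeAdapter<Bf16Storage>::from_compute(a[i]);
        hb[i] = TypeAdapter<Bf16Storage>::from_compute(b[i]);
    }
    assert(row_dot(a, b, 3) == row_dot(ha, hb, 3));
}
```

Call sites never name a conversion intrinsic directly, so changing how a storage type is widened or narrowed touches only its specialization.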

- [ ] **Step 4: Run test to verify it passes**

Run:

```bash
ctest --preset release --output-on-failure -R DTypeTest
```

Expected: in a CUDA-enabled environment, the FP16/FP32 tests PASS.

- [ ] **Step 5: Commit**

```bash
git add src/kernels/impl/type_adapter.cuh src/kernels/impl/tile_io.cuh src/forward/flash_attention_forward_typed.cu src/backward/flash_attention_backward_typed.cu
git commit -m "refactor(kernels): unify primitive type conversion"
```

### Task 3: Preserve and extend the regression surface

**Files:**
- Modify: `tests/unit/test_online_softmax.cu`
- Modify: `tests/unit/test_matmul.cu`

- [ ] **Step 1: Write the failing test**

Keep the key regressions:

```cpp
TEST_F(OnlineSoftmaxTest, Forward_MultiBlockCrossWarpMatchesReference);
TEST_F(OnlineSoftmaxTest, FinalizeNullNormalizerReturnsError);
TEST_F(MatmulTest, ABt_HeadDim128_LargeTile);
```

- [ ] **Step 2: Run test to verify it fails**

Run:

```bash
ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest"
```

Expected: in a CUDA-enabled environment, the old implementation fails on cross-warp/multi-block correctness and on large-smem launch.

- [ ] **Step 3: Write minimal implementation**

No new production API. Just keep tests aligned with the bug fixes:

```cpp
EXPECT_NEAR(h_output[i], h_output_expected[i], 1e-4f);
ASSERT_EQ(err, FlashAttentionError::SUCCESS);
```
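
For readers unfamiliar with what a "matches reference" check compares, the following host-only sketch may help. It accumulates the running maximum `m` and normalizer `l` tile by tile, then checks that `m + log(l)` (the same quantity the forward kernel stores as `m_tile[row] + logf(l_tile[row])`) agrees with a direct log-sum-exp within the `1e-4f` tolerance used above. `SoftmaxState`, `online_softmax_state`, and `logsumexp_reference` are hypothetical names; this is neither the project's kernel nor its test.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Running state of an online softmax over one row.
struct SoftmaxState {
    float m;  // running maximum
    float l;  // running sum of exp(x - m)
};

// Processes the row in tiles of `tile` elements, rescaling the old sum
// whenever a new tile raises the running maximum.
inline SoftmaxState online_softmax_state(const std::vector<float>& x, std::size_t tile) {
    assert(tile > 0);
    SoftmaxState s{-std::numeric_limits<float>::infinity(), 0.0f};
    for (std::size_t start = 0; start < x.size(); start += tile) {
        const std::size_t end = std::min(start + tile, x.size());
        float new_m = s.m;
        for (std::size_t i = start; i < end; ++i) new_m = std::max(new_m, x[i]);
        // Nothing to rescale before the first tile has contributed.
        float l = (s.l == 0.0f) ? 0.0f : s.l * std::exp(s.m - new_m);
        for (std::size_t i = start; i < end; ++i) l += std::exp(x[i] - new_m);
        s.m = new_m;
        s.l = l;
    }
    return s;
}

// Direct two-pass reference: global max first, then the sum.
inline float logsumexp_reference(const std::vector<float>& x) {
    assert(!x.empty());
    const float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    for (float v : x) sum += std::exp(v - m);
    return m + std::log(sum);
}
```

The result should not depend on the tile size, which is the property a multi-block test exercises on the GPU side.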

- [ ] **Step 4: Run test to verify it passes**

Run:

```bash
ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest"
```

Expected: in a CUDA-enabled environment, the new regressions PASS.

- [ ] **Step 5: Commit**

```bash
git add tests/unit/test_online_softmax.cu tests/unit/test_matmul.cu
git commit -m "test(kernels): lock down primitive wrapper regressions"
```

### Task 4: Final hygiene and verification

**Files:**
- Modify: all touched files

- [ ] **Step 1: Format modified files**

Run:

```bash
clang-format -i src/kernels/primitive_api_utils.cuh src/kernels/impl/type_adapter.cuh src/kernels/impl/tile_io.cuh src/kernels/online_softmax.cu src/kernels/matmul.cu src/kernels/tile_io.cu src/forward/flash_attention_forward_typed.cu src/backward/flash_attention_backward_typed.cu tests/unit/test_online_softmax.cu tests/unit/test_matmul.cu
```

- [ ] **Step 2: Run diff sanity checks**

Run:

```bash
git --no-pager diff --check
git --no-pager diff --stat
```

Expected: no whitespace errors; diff concentrated in primitive-layer files.

- [ ] **Step 3: Run repo verification**

Run:

```bash
cmake --preset release
cmake --build --preset release
ctest --preset release --output-on-failure
```

Expected: the current machine still blocks at configure because `nvcc` is missing. Once a CUDA environment is available, all three steps must be green.

- [ ] **Step 4: Commit**

```bash
git add src tests docs/superpowers/plans/2026-05-22-primitive-layer-deepening.md
git commit -m "refactor(kernels): deepen standalone primitive layer"
```

src/backward/flash_attention_backward_typed.cu

Lines changed: 8 additions & 3 deletions
@@ -28,7 +28,8 @@ __global__ void __launch_bounds__(128)
     float sum = 0.0f;
 #pragma unroll
     for (int d = 0; d < HEAD_DIM; d++) {
-        sum += impl::to_float(dO_row[d]) * impl::to_float(O_row[d]);
+        sum += impl::TypeAdapter<InputT>::to_compute(dO_row[d]) *
+               impl::TypeAdapter<InputT>::to_compute(O_row[d]);
     }
 
     D[batch_head_idx * seq_len + row_idx] = sum;
@@ -81,7 +82,9 @@ __global__ void __launch_bounds__(128)
     }
     for (int i = tid; i < BLOCK_M; i += num_threads) {
         int global_idx = q_start + i;
-        L_tile[i] = (global_idx < seq_len) ? impl::to_float(L_ptr[global_idx]) : 0.0f;
+        L_tile[i] = (global_idx < seq_len)
+                        ? impl::TypeAdapter<InputT>::to_compute(L_ptr[global_idx])
+                        : 0.0f;
         D_tile[i] = (global_idx < seq_len) ? D_ptr[global_idx] : 0.0f;
     }
     __syncthreads();
@@ -205,7 +208,9 @@ __global__ void __launch_bounds__(128)
 
     for (int i = tid; i < BLOCK_M; i += num_threads) {
         int global_idx = q_start + i;
-        L_tile[i] = (global_idx < seq_len) ? impl::to_float(L_ptr[global_idx]) : 0.0f;
+        L_tile[i] = (global_idx < seq_len)
+                        ? impl::TypeAdapter<InputT>::to_compute(L_ptr[global_idx])
+                        : 0.0f;
         D_tile[i] = (global_idx < seq_len) ? D_ptr[global_idx] : 0.0f;
     }
     __syncthreads();

src/forward/flash_attention_forward_typed.cu

Lines changed: 3 additions & 2 deletions
@@ -163,11 +163,12 @@ __global__ void __launch_bounds__(128)
             float l_inv = 1.0f / l_tile[row];
             for (int d = 0; d < HEAD_DIM; d++) {
                 O_ptr[global_row * HEAD_DIM + d] =
-                    InputT(O_tile[row * HEAD_DIM + d] * l_inv);  // Implicit float->half if needed
+                    impl::TypeAdapter<InputT>::from_compute(O_tile[row * HEAD_DIM + d] * l_inv);
             }
 
             // Store logsumexp for backward pass
-            L_ptr[global_row] = InputT(m_tile[row] + logf(l_tile[row]));
+            L_ptr[global_row] =
+                impl::TypeAdapter<InputT>::from_compute(m_tile[row] + logf(l_tile[row]));
         }
     }
 

src/kernels/impl/online_softmax.cuh

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,12 @@ __device__ __forceinline__ float block_reduce_max(float val, float* shared) {
120120
for (int offset = 16; offset > 0; offset /= 2) {
121121
val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
122122
}
123+
if (lane == 0) {
124+
shared[0] = val;
125+
}
123126
}
124-
125-
return val;
127+
__syncthreads();
128+
return shared[0];
126129
}
127130

128131
/// Block-level sum reduction using shared memory.
@@ -147,9 +150,12 @@ __device__ __forceinline__ float block_reduce_sum(float val, float* shared) {
147150
for (int offset = 16; offset > 0; offset /= 2) {
148151
val += __shfl_xor_sync(0xffffffff, val, offset);
149152
}
153+
if (lane == 0) {
154+
shared[0] = val;
155+
}
150156
}
151-
152-
return val;
157+
__syncthreads();
158+
return shared[0];
153159
}
154160

155161
} // namespace impl
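
The effect of the `shared[0]` broadcast in the two hunks above can be pictured with a small CPU emulation. This is a sketch under an assumption: the part of `block_reduce_max` not visible in the diff is taken to follow the common two-stage scheme, where each warp reduces its own lanes and then only the first warp reduces the per-warp partials. `emulate_block_reduce_max` and `kWarpSize` are hypothetical names, and nothing here runs on a GPU.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kWarpSize = 32;

// Returns the value each emulated thread would receive from the reduction.
// broadcast == false models returning the register value `val`;
// broadcast == true models writing shared[0], syncing, and returning shared[0].
inline std::vector<float> emulate_block_reduce_max(const std::vector<float>& vals,
                                                   bool broadcast) {
    assert(!vals.empty());
    const std::size_t n = vals.size();
    const std::size_t num_warps = (n + kWarpSize - 1) / kWarpSize;
    std::vector<float> shared(num_warps);
    std::vector<float> result(n);

    // Stage 1: every warp reduces its own lanes and publishes one partial.
    for (std::size_t w = 0; w < num_warps; ++w) {
        const std::size_t begin = w * kWarpSize;
        const std::size_t end = std::min(begin + kWarpSize, n);
        const float warp_max = *std::max_element(vals.begin() + begin, vals.begin() + end);
        shared[w] = warp_max;
        std::fill(result.begin() + begin, result.begin() + end, warp_max);
    }

    // Stage 2: only the first warp reduces the partials, so only its
    // threads hold the block-wide maximum in their registers.
    const float block_max = *std::max_element(shared.begin(), shared.end());
    std::fill(result.begin(), result.begin() + std::min(kWarpSize, n), block_max);

    if (broadcast) {
        // Publish through shared[0]; after the barrier every thread reads it.
        shared[0] = block_max;
        std::fill(result.begin(), result.end(), shared[0]);
    }
    return result;
}
```

With 64 inputs whose maximum sits in the first warp, the non-broadcast variant leaves the second warp holding only its own partial maximum, which is the kind of cross-warp discrepancy the new regression test is meant to catch.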
