44#include < float.h>
55
66#include " cuflash/flash_attention.h"
7+ #include " impl/online_softmax.cuh"
78#include " impl/tile_io.cuh"
89#include " kernel_launch_utils.cuh"
910
@@ -105,13 +106,15 @@ __global__ void __launch_bounds__(128)
105106 if (q_start + row >= seq_len)
106107 continue ;
107108
109+ // Compute row max for this KV block
108110 float row_max = -INFINITY ;
109111 for (int j = 0 ; j < BLOCK_N ; j++) {
110112 if (kv_start + j < seq_len) {
111113 row_max = fmaxf (row_max, S_tile[row * BLOCK_N + j]);
112114 }
113115 }
114116
117+ // Compute row sum (exp) and convert scores to probabilities
115118 float row_sum = 0 .0f ;
116119 for (int j = 0 ; j < BLOCK_N ; j++) {
117120 if (kv_start + j < seq_len) {
@@ -122,30 +125,31 @@ __global__ void __launch_bounds__(128)
122125 }
123126 }
124127
125- // Update online softmax state
126- float m_old = m_tile[row];
127- float l_old = l_tile[row];
128- float m_new = fmaxf (m_old, row_max);
129- float l_new = l_old * expf (m_old - m_new) + row_sum * expf (row_max - m_new);
128+ // Update online softmax state using unified algorithm
129+ impl::OnlineSoftmaxState state;
130+ state.m = m_tile[row];
131+ state.l = l_tile[row];
132+
133+ float rescale_existing, scale_new;
134+ state.update_with_rescale (row_max, row_sum, rescale_existing, scale_new);
130135
131136 // Rescale existing O
132- float rescale = expf (m_old - m_new);
133137 for (int d = 0 ; d < HEAD_DIM ; d++) {
134- O_tile[row * HEAD_DIM + d] *= rescale ;
138+ O_tile[row * HEAD_DIM + d] *= rescale_existing ;
135139 }
136140
137141 // Add contribution from this block: P @ V
138- float p_scale = expf (row_max - m_new);
139142 for (int d = 0 ; d < HEAD_DIM ; d++) {
140143 float sum = 0 .0f ;
141144 for (int j = 0 ; j < BLOCK_N ; j++) {
142145 sum += S_tile[row * BLOCK_N + j] * V_tile[j * HEAD_DIM + d];
143146 }
144- O_tile[row * HEAD_DIM + d] += sum * p_scale ;
147+ O_tile[row * HEAD_DIM + d] += sum * scale_new ;
145148 }
146149
147- m_tile[row] = m_new;
148- l_tile[row] = l_new;
150+ // Store updated state
151+ m_tile[row] = state.m ;
152+ l_tile[row] = state.l ;
149153 }
150154 __syncthreads ();
151155 }
@@ -199,10 +203,11 @@ template<>
199203FlashAttentionError launch_flash_attention_forward_typed<float >(
200204 const float * Q, const float * K, const float * V, float * O, float * L, int batch_size,
201205 int num_heads, int seq_len, int head_dim, float scale, bool causal, cudaStream_t stream) {
202- constexpr int BLOCK_M = 64 ;
203- constexpr int BLOCK_N = 64 ;
204- constexpr int BLOCK_M_HD128 = 32 ;
205- constexpr int BLOCK_N_HD128 = 32 ;
206+ using Config = impl::ForwardTilingConfig;
207+ constexpr int BLOCK_M = Config::BLOCK_M ;
208+ constexpr int BLOCK_N = Config::BLOCK_N ;
209+ constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128 ;
210+ constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128 ;
206211
207212 const int batch_heads = batch_size * num_heads;
208213 const int num_q_blocks = (seq_len + BLOCK_M - 1 ) / BLOCK_M ;
@@ -271,10 +276,11 @@ template<>
271276FlashAttentionError launch_flash_attention_forward_typed<half>(
272277 const half* Q, const half* K, const half* V, half* O, half* L, int batch_size, int num_heads,
273278 int seq_len, int head_dim, float scale, bool causal, cudaStream_t stream) {
274- constexpr int BLOCK_M = 64 ;
275- constexpr int BLOCK_N = 64 ;
276- constexpr int BLOCK_M_HD128 = 32 ;
277- constexpr int BLOCK_N_HD128 = 32 ;
279+ using Config = impl::ForwardTilingConfig;
280+ constexpr int BLOCK_M = Config::BLOCK_M ;
281+ constexpr int BLOCK_N = Config::BLOCK_N ;
282+ constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128 ;
283+ constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128 ;
278284
279285 const int batch_heads = batch_size * num_heads;
280286 const int num_q_blocks = (seq_len + BLOCK_M - 1 ) / BLOCK_M ;
0 commit comments