diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index f19df233628..4fb4a04f296 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -418,7 +418,8 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ python -m executorch.examples.models.qwen3_5_moe.export \ --prequantized "$LOCAL_MODEL_DIR" \ - --output-dir "${OUTPUT_DIR}" + --output-dir "${OUTPUT_DIR}" \ + --moe-activation-dtype int8 echo "::endgroup::" test -f "${OUTPUT_DIR}/model.pte" diff --git a/backends/cuda/tests/test_fused_moe.py b/backends/cuda/tests/test_fused_moe.py index e23832b89ea..bbc351bc47b 100644 --- a/backends/cuda/tests/test_fused_moe.py +++ b/backends/cuda/tests/test_fused_moe.py @@ -31,6 +31,7 @@ from executorch.backends.cuda.triton.kernels.fused_moe import ( fused_moe as triton_fused_moe, fused_moe_batched as triton_fused_moe_batched, + fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8, moe_align_block_size, ) from executorch.exir import ( @@ -212,6 +213,11 @@ def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base): class TestFusedMoE(unittest.TestCase): + # TODO: migrate from manual max_abs/max_ref relative checks to + # torch.allclose(atol=, rtol=). Current tests use per-tensor-max relative + # error which is looser than per-element allclose — need to calibrate atol + # for INT4 quantization noise floor across random weight magnitudes. + def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA is not available") @@ -487,6 +493,152 @@ def test_e2e_cpp_runner(self): ) +class TestFusedMoEBatchedInt8(unittest.TestCase): + """Correctness tests for the INT8 dynamic-activation batched MoE kernel.""" + + INT8_TEST_CONFIGS = [ + (42, 8, 64, 32, 4, 2, 32, "8tok_small"), + (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"), + (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"), + (55, 64, 64, 32, 4, 2, 32, "64tok"), + (99, 128, 128, 64, 8, 2, 32, "128tok"), + (0, 256, 128, 64, 8, 2, 32, "256tok"), + ] + + def test_int8_correctness(self): + """INT8 batched kernel matches reference across M values.""" + for ( + seed, + M, + hidden, + intermediate, + num_experts, + top_k, + gs, + desc, + ) in self.INT8_TEST_CONFIGS: + with self.subTest(desc=desc): + torch.manual_seed(seed) + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + w1_weight = torch.randn( + num_experts, + 2 * intermediate, + hidden, + dtype=torch.bfloat16, + device="cuda", + ) + w2_weight = torch.randn( + num_experts, + hidden, + intermediate, + dtype=torch.bfloat16, + device="cuda", + ) + w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs) + w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs) + w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda() + + scores = torch.randn(M, num_experts, device="cuda") + topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1) + topk_weights = topk_weights.softmax(dim=-1).float() + + out_int8 = triton_fused_moe_batched_int8( + x, + w1, + w1s, + w2, + w2s, + topk_weights, + topk_ids, + top_k, + num_experts, + gs, + ) + + w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda() + w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda() + ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k) + + diff = (out_int8.float() - ref.float()).abs().max().item() + rel = diff / (ref.float().abs().max().item() + 1e-10) + self.assertLess( + rel, + 0.10, + f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})", + ) + + def test_int8_matches_bf16_batched(self): + """INT8 batched output is close to BF16 batched output.""" + for ( + seed, + M, + hidden, + intermediate, + num_experts, + top_k, + gs, + desc, + ) in self.INT8_TEST_CONFIGS: + with self.subTest(desc=desc): + torch.manual_seed(seed) + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + w1_weight = torch.randn( + num_experts, + 2 * intermediate, + hidden, + dtype=torch.bfloat16, + device="cuda", + ) + w2_weight = torch.randn( + num_experts, + hidden, + intermediate, + dtype=torch.bfloat16, + device="cuda", + ) + w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs) + w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs) + w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda() + + scores = torch.randn(M, num_experts, device="cuda") + topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1) + topk_weights = topk_weights.softmax(dim=-1).float() + + out_bf16 = triton_fused_moe_batched( + x, + w1, + w1s, + w2, + w2s, + topk_weights, + topk_ids, + top_k, + num_experts, + gs, + ) + + out_int8 = triton_fused_moe_batched_int8( + x, + w1, + w1s, + w2, + w2s, + topk_weights, + topk_ids, + top_k, + num_experts, + gs, + ) + + diff = (out_int8.float() - out_bf16.float()).abs().max().item() + rel = diff / (out_bf16.float().abs().max().item() + 1e-10) + self.assertLess( + rel, + 0.15, + f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})", + ) + + class TestMoeAlignBlockSize(unittest.TestCase): def setUp(self): if not torch.cuda.is_available(): diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 2f9119efb55..9dae95cb503 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -25,7 +25,8 @@ Fused MoE Triton Kernels for ExecuTorch CUDA Backend. Performs grouped GEMM for Mixture-of-Experts with INT4 weight-only -quantization (W4A16). Two kernel variants: +quantization (W4A16) or INT4 weights + INT8 activations (W4A8). +Two kernel families (bf16 and int8), each with two variants: - fused_moe: vec-mat per-pair kernel for decode (M=1). - fused_moe_batched_gemm: token-sorted tensor-core kernel for prefill (M>>1). @@ -703,6 +704,148 @@ def _fused_moe_batched_kernel( tl.store(c_ptrs, acc.to(compute_type), mask=n_mask[None, :]) +# Autotune configs for batched INT8 GEMM1 (gate+up projection, W4A8). +_BATCHED_GEMM1_INT8_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3), + triton.Config( + {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2 + ), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4), +] + + +@triton.autotune(configs=_BATCHED_GEMM1_INT8_CONFIGS, key=["N", "K"]) +@triton.jit +def _fused_moe_batched_int8_kernel( + # Pointers + A, # [M+1, K] bf16 activations (row M is zero-padding sentinel) + B, # [E, N, K//2] int8 packed INT4 weights + C, # [num_tokens_post_padded, N] bf16 output (sorted order) + B_scale, # [E, N, K//group_size] bf16 scales + sorted_token_ids, # [num_tokens_post_padded] int64 pair indices + expert_ids, # [num_expert_blocks] int64 + # Dimensions + N: tl.constexpr, + K: tl.constexpr, + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + # Config + top_k: tl.constexpr, + group_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + compute_type: tl.constexpr, +): + """Batched GEMM1 (gate+up) with INT8 tensor cores (W4A8). + + Dynamically quantizes bf16 activations to INT8 per-row per-tile, + dequantizes INT4 weights to INT8 (skipping bf16), and uses + tl.dot(int8, int8) → int32 accumulation with per-tile float32 rescale. + """ + pid = tl.program_id(0) + num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N) + expert_block_idx = pid // num_n_blocks + n_block = pid % num_n_blocks + + expert_id = tl.load(expert_ids + expert_block_idx).to(tl.int64) + + offs_m = expert_block_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + pair_ids = tl.load(sorted_token_ids + offs_m) + token_ids = pair_ids // top_k + + offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + n_mask = offs_n < N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = A + token_ids[:, None] * stride_am + offs_k[None, :] * stride_ak + + b_ptrs = ( + B + + expert_id * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_n[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + + # Float32 accumulator for cross-tile summation (rescaled per tile) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_remaining = K - k_step * BLOCK_SIZE_K + k_mask = offs_k < k_remaining + + # Load bf16 activation tile [BLOCK_M, BLOCK_K] + a_bf16 = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0) + + # Per-row dynamic INT8 quantization + a_f32 = a_bf16.to(tl.float32) + a_absmax = tl.max(tl.abs(a_f32), axis=1) # [BLOCK_M] + a_scale = a_absmax / 127.0 + 1e-12 # avoid division by zero + a_scaled = a_f32 / a_scale[:, None] + a_int8 = (a_scaled + tl.where(a_scaled >= 0, 0.5, -0.5)).to(tl.int8) + + # Load and unpack INT4 weights to INT8 [BLOCK_K, BLOCK_N] + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) + b = (b >> b_shifter) & 0xF + b_int8 = (b - 8).to(tl.int8) # symmetric dequant to [-8, 7] + + # Per-group weight scale + if BLOCK_SIZE_K <= group_size: + group_idx = (BLOCK_SIZE_K * k_step) // group_size + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + group_idx * stride_bsk + ) + b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to( + tl.float32 + ) + else: + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk + ) + b_scale = tl.load( + scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0 + ).to(tl.float32) + + if BLOCK_SIZE_K <= group_size: + # INT8 tensor core GEMM: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] → int32 + dot_i32 = tl.dot(a_int8, b_int8) + # b_scale is [1, BLOCK_N], broadcast + acc += dot_i32.to(tl.float32) * a_scale[:, None] * b_scale + else: + # Multi-group tile: dequantize weights per group, use float matmul + b_dequant = (b_int8.to(tl.float32) * b_scale).to(compute_type) + acc += ( + tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32) + * a_scale[:, None] + ) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + + # Write output in sorted order [BLOCK_M, BLOCK_N] + c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, acc.to(compute_type), mask=n_mask[None, :]) + + @triton.autotune(configs=_BATCHED_GEMM2_CONFIGS, key=["N", "K"]) @triton.jit def _fused_moe_silu_batched_kernel( @@ -834,6 +977,159 @@ def _fused_moe_silu_batched_kernel( tl.store(c_ptrs, acc.to(compute_type), mask=n_mask[None, :]) +# Autotune configs for batched INT8 GEMM2 (down projection + SiLU, W4A8). +_BATCHED_GEMM2_INT8_CONFIGS = [ + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2), + triton.Config( + {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2 + ), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4), +] + + +@triton.autotune(configs=_BATCHED_GEMM2_INT8_CONFIGS, key=["N", "K"]) +@triton.jit +def _fused_moe_silu_batched_int8_kernel( + # Pointers + A, # [num_tokens_post_padded, 2*inter] bf16 GEMM1 output (sorted order) + B, # [E, N, K//2] int8 packed INT4 weights + C, # [M*top_k + 1, N] bf16 output (scatter to original pair order) + B_scale, # [E, N, K//group_size] bf16 scales + sorted_token_ids, # [num_tokens_post_padded] int64 pair indices + expert_ids, # [num_expert_blocks] int64 + topk_weights, # [M*top_k] float32 router weights (flat) + # Dimensions + N: tl.constexpr, + K: tl.constexpr, # intermediate_size + num_pairs, # M * top_k (for clamping sentinel weight lookups) + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + # Config + top_k: tl.constexpr, + group_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + compute_type: tl.constexpr, +): + """Batched GEMM2 with fused SiLU, INT8 tensor cores, and scatter-back (W4A8). + + SiLU(gate)*up is computed in float32, then dynamically quantized to INT8 + per-row per-tile. INT4 weights are dequantized directly to INT8. + tl.dot(int8, int8) → int32, with per-tile float32 rescale. + """ + pid = tl.program_id(0) + num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N) + expert_block_idx = pid // num_n_blocks + n_block = pid % num_n_blocks + + expert_id = tl.load(expert_ids + expert_block_idx).to(tl.int64) + + offs_m = expert_block_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + pair_ids = tl.load(sorted_token_ids + offs_m) + + offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + n_mask = offs_n < N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # A pointers: gate at [0, K), up at [K, 2K) + a_gate_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + a_up_ptrs = a_gate_ptrs + K * stride_ak + + b_ptrs = ( + B + + expert_id * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_n[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + k_remaining = K - k_step * BLOCK_SIZE_K + k_mask = offs_k < k_remaining + + # Load gate and up tiles, apply SiLU in float32 + gate = tl.load(a_gate_ptrs, mask=k_mask[None, :], other=0.0).to(tl.float32) + up = tl.load(a_up_ptrs, mask=k_mask[None, :], other=0.0) + silu_out = gate * tl.sigmoid(gate) * up.to(tl.float32) # [BLOCK_M, BLOCK_K] + + # Per-row dynamic INT8 quantization of SiLU output + a_absmax = tl.max(tl.abs(silu_out), axis=1) # [BLOCK_M] + a_scale = a_absmax / 127.0 + 1e-12 + a_scaled = silu_out / a_scale[:, None] + a_int8 = (a_scaled + tl.where(a_scaled >= 0, 0.5, -0.5)).to(tl.int8) + + # Load and unpack INT4 weights to INT8 [BLOCK_K, BLOCK_N] + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) + b = (b >> b_shifter) & 0xF + b_int8 = (b - 8).to(tl.int8) + + # Per-group weight scale + if BLOCK_SIZE_K <= group_size: + group_idx = (BLOCK_SIZE_K * k_step) // group_size + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + group_idx * stride_bsk + ) + b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to( + tl.float32 + ) + else: + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk + ) + b_scale = tl.load( + scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0 + ).to(tl.float32) + + if BLOCK_SIZE_K <= group_size: + # INT8 tensor core GEMM: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] → int32 + dot_i32 = tl.dot(a_int8, b_int8) + acc += dot_i32.to(tl.float32) * a_scale[:, None] * b_scale + else: + # Multi-group tile: dequantize weights per group, use float matmul + b_dequant = (b_int8.to(tl.float32) * b_scale).to(compute_type) + acc += ( + tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32) + * a_scale[:, None] + ) + + a_gate_ptrs += BLOCK_SIZE_K * stride_ak + a_up_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + + # Apply router weights per row + safe_pair_ids = tl.minimum(pair_ids, num_pairs - 1) + weights = tl.load(topk_weights + safe_pair_ids) + is_valid = pair_ids < num_pairs + weights = tl.where(is_valid, weights, 0.0) + acc = acc * weights[:, None] + + # Scatter to original pair order + scatter_ids = tl.where(is_valid, pair_ids, num_pairs) + c_ptrs = C + scatter_ids[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, acc.to(compute_type), mask=n_mask[None, :]) + + # --------------------------------------------------------------------------- # Batched triton_op wrapper # --------------------------------------------------------------------------- @@ -967,6 +1263,134 @@ def _fused_moe_batched_gemm_fake( return torch.empty_like(hidden_states) +@triton_op("triton::fused_moe_batched_gemm_int8", mutates_args={}) +def fused_moe_batched_gemm_int8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + num_experts: int, + group_size: int, +) -> torch.Tensor: + """Batched W4A8 GEMM1 + GEMM2+SiLU with INT8 tensor cores.""" + M, K = hidden_states.shape + N1 = w1.shape[1] + intermediate = N1 // 2 + N2 = w2.shape[1] + num_pairs = M * top_k + BLOCK_M = _BATCHED_BLOCK_M + + sorted_token_ids, expert_ids, _ = moe_align_block_size( + topk_ids, BLOCK_M, num_experts + ) + max_padded = sorted_token_ids.shape[0] + num_expert_blocks = expert_ids.shape[0] + + hidden_padded = torch.cat( + [ + hidden_states, + torch.zeros(1, K, dtype=hidden_states.dtype, device=hidden_states.device), + ], + dim=0, + ) + + topk_weights_flat = topk_weights.reshape(-1) + + cache1 = torch.empty( + max_padded, + N1, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + def grid1(meta): + return (num_expert_blocks * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),) + + wrap_triton(_fused_moe_batched_int8_kernel)[grid1]( + hidden_padded, + w1, + cache1, + w1_scale, + sorted_token_ids, + expert_ids, + N=N1, + K=K, + stride_am=hidden_padded.stride(0), + stride_ak=hidden_padded.stride(1), + stride_be=w1.stride(0), + stride_bk=w1.stride(2), + stride_bn=w1.stride(1), + stride_cm=cache1.stride(0), + stride_cn=cache1.stride(1), + stride_bse=w1_scale.stride(0), + stride_bsk=w1_scale.stride(2), + stride_bsn=w1_scale.stride(1), + top_k=top_k, + group_size=group_size, + BLOCK_SIZE_M=BLOCK_M, + compute_type=tl.bfloat16, + ) + + out_buf = torch.zeros( + num_pairs + 1, + N2, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + def grid2(meta): + return (num_expert_blocks * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),) + + wrap_triton(_fused_moe_silu_batched_int8_kernel)[grid2]( + cache1, + w2, + out_buf, + w2_scale, + sorted_token_ids, + expert_ids, + topk_weights_flat, + N=N2, + K=intermediate, + num_pairs=num_pairs, + stride_am=cache1.stride(0), + stride_ak=cache1.stride(1), + stride_be=w2.stride(0), + stride_bk=w2.stride(2), + stride_bn=w2.stride(1), + stride_cm=out_buf.stride(0), + stride_cn=out_buf.stride(1), + stride_bse=w2_scale.stride(0), + stride_bsk=w2_scale.stride(2), + stride_bsn=w2_scale.stride(1), + top_k=top_k, + group_size=group_size, + BLOCK_SIZE_M=BLOCK_M, + compute_type=tl.bfloat16, + ) + + return out_buf[:num_pairs].view(M, top_k, N2).sum(dim=1) + + +@fused_moe_batched_gemm_int8.register_fake +def _fused_moe_batched_gemm_int8_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + num_experts: int, + group_size: int, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + def fused_moe_batched( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -978,8 +1402,22 @@ def fused_moe_batched( top_k: int, num_experts: int, group_size: int, + activation_dtype: str = "bf16", ) -> torch.Tensor: - """Convenience wrapper for benchmarking (same as fused_moe_batched_gemm).""" + """Convenience wrapper that dispatches to bf16 or int8 batched kernels.""" + if activation_dtype == "int8": + return fused_moe_batched_gemm_int8( + hidden_states, + w1, + w1_scale, + w2, + w2_scale, + topk_weights, + topk_ids, + top_k, + num_experts, + group_size, + ) return fused_moe_batched_gemm( hidden_states, w1, diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index ac6c112c08c..8e12d0236dd 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -535,11 +535,12 @@ def _apply_turboquant(model, config): # --------------------------------------------------------------------------- -def _set_batched_moe(model, enabled): +def _set_batched_moe(model, enabled, moe_activation_dtype="bf16"): """Toggle batched tensor-core MoE kernel for all MoE layers.""" for layer in model.layers: if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"): layer.mlp.experts.use_batched_moe = enabled + layer.mlp.experts.moe_activation_dtype = moe_activation_dtype def export_and_lower(model, config, args): @@ -782,7 +783,8 @@ def _export_cuda(model, config, args): # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes # that reject longer prompts at runtime. - _set_batched_moe(model, True) + moe_activation_dtype = getattr(args, "moe_activation_dtype", "bf16") + _set_batched_moe(model, True, moe_activation_dtype=moe_activation_dtype) print("Exporting prefill method...") example_prefill_len = config.max_seq_len - 1 @@ -946,6 +948,12 @@ def main(): # noqa: C901 action="store_true", help="Disable split-K (flash-decoding) SDPA for decode; use tiled SDPA instead.", ) + parser.add_argument( + "--moe-activation-dtype", + choices=["bf16", "int8"], + default="bf16", + help="MoE activation dtype for prefill only. Decode always uses bf16. bf16 (default): W4A16 batched GEMM. int8: W4A8 with INT8 tensor cores.", + ) args = parser.parse_args() if args.model_id: diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 81c093f5652..f187ddb8c15 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -479,6 +479,7 @@ def __init__(self, config): self.hidden_size = config.hidden_size self.group_size = 32 self.use_batched_moe = False + self.moe_activation_dtype = "bf16" self.w1_weight = nn.Parameter( torch.empty( @@ -497,6 +498,19 @@ def __init__(self, config): def forward(self, x, expert_weights, expert_indices, top_k): if self.use_batched_moe: + if self.moe_activation_dtype == "int8": + return torch.ops.triton.fused_moe_batched_gemm_int8( + x, + self.w1, + self.w1_scale, + self.w2, + self.w2_scale, + expert_weights, + expert_indices, + top_k, + self.num_experts, + self.group_size, + ) return torch.ops.triton.fused_moe_batched_gemm( x, self.w1,