Skip to content

Commit 9822d2e

Browse files
Daniel Shenclaude
andcommitted
test(moe): fix stale unit tests broken by lazy DeepEP buffer + packed-param requires_grad
Two L0 unit tests were stale relative to earlier branch code changes: - test_grouped_experts_deepep_token_dispatcher_init asserted init_token_dispatcher eagerly calls _init_deepep_buffer, but buffer allocation is now lazy (deferred to FusedDispatch.forward) — the revert that fixed the single-node load-time OOM. Assert it is NOT called. - ExpertParallel._partition_fn now constructs nn.Parameter(..., requires_grad=...) so non-floating packed mxfp4 params (int8 / e8m0) don't trip the default requires_grad=True. The test's stub Parameter didn't accept/store requires_grad; add it (also unblocks the requires_grad-preservation test). Both fixes verified: tests/unit_tests/moe now 450 passed, 0 failed. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Daniel Shen <dshen@crusoe.ai>
1 parent 3676cd4 commit 9822d2e

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

tests/unit_tests/moe/test_experts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,10 @@ def test_grouped_experts_deepep_token_dispatcher_init(self, moe_config):
759759
assert hasattr(experts, "token_dispatcher")
760760
assert experts.ep_size == 2
761761
assert experts.ep_rank == 0
762-
mock_init_buffer.assert_called_once_with(mock_mesh.get_group.return_value)
762+
# The DeepEP NVSHMEM buffer is allocated lazily (in FusedDispatch.forward),
763+
# not eagerly in init_token_dispatcher — the revert that fixed the single-node
764+
# load-time OOM. So init_token_dispatcher must NOT call _init_deepep_buffer.
765+
mock_init_buffer.assert_not_called()
763766

764767
def test_grouped_experts_deepep_apply_bias_no_bias(self, moe_config):
765768
"""Test _apply_bias method with no bias."""

tests/unit_tests/moe/test_parallelizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ def _install_torch_and_layers_stubs(monkeypatch):
8484
nn_stub = types.ModuleType("torch.nn")
8585

8686
class Parameter:
87-
def __init__(self, data=None):
87+
def __init__(self, data=None, requires_grad=True):
8888
self.data = data
89+
self.requires_grad = requires_grad
8990

9091
class Module:
9192
pass

0 commit comments

Comments
 (0)