Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
68f98aa
fix(models): audit fp32 protected dtype casts
yuhezhang-ai Jun 16, 2026
0abf403
fix(gpt-oss): isolate attention sinks in fp32 holder
yuhezhang-ai Jun 16, 2026
0611261
fix(nemotron-v3): isolate mamba fp32 params
yuhezhang-ai Jun 16, 2026
a0e0317
fix(nemotron-v3): keep mamba D fp32
yuhezhang-ai Jun 16, 2026
4a73380
fix(distributed): keep fp32 holders unresharded in forward
yuhezhang-ai Jun 16, 2026
00f3e9b
fix(distributed): thread reshard through dtype-aware sharding
yuhezhang-ai Jun 16, 2026
b035704
fix(distributed): warn on pp reshard override
yuhezhang-ai Jun 16, 2026
9e1db7b
fix(gpt-oss): keep attention sinks resident dtype
yuhezhang-ai Jun 16, 2026
f623f08
Merge remote-tracking branch 'origin/main' into yuhez/fix/fp32-protec…
yuhezhang-ai Jun 22, 2026
6c82ec2
Merge branch 'main' into yuhez/fix/fp32-protected-tensor-audit
akoumpa Jun 23, 2026
3c516cf
fix(models): harden fp32 protected tensor gathers
yuhezhang-ai Jun 23, 2026
0d5f9f1
Merge remote-tracking branch 'origin/main' into yuhez/fix/fp32-protec…
yuhezhang-ai Jun 23, 2026
7bc4c56
fix(models): clone fp32 holder outputs after gather
yuhezhang-ai Jun 23, 2026
9dbf3d9
fix(distributed): reshard fp32 holder submodules normally
yuhezhang-ai Jun 23, 2026
2f115f3
Preserve Nemotron fp32-protected dtypes in HF export
yuhezhang-ai Jun 23, 2026
523aa92
fix(checkpoint): preserve qwen gdn fp32 export dtypes
yuhezhang-ai Jun 23, 2026
9965deb
Merge branch 'main' into yuhez/fix/fp32-protected-tensor-audit
akoumpa Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,29 @@ def _normalize_dtype_mapping_to_state_dict_keys(
return normalized


def _apply_adapter_forced_dtype_mapping(
model: nn.Module,
state_dict: dict[str, torch.Tensor],
fqn_to_dtype_mapping: dict[str, str],
) -> dict[str, str]:
"""Let model adapters override original HF dtype metadata for export-only keys."""
adapter = getattr(model, "state_dict_adapter", None)
forced_dtype_mapping = getattr(adapter, "forced_hf_dtype_mapping", None)
if not callable(forced_dtype_mapping):
return fqn_to_dtype_mapping

forced = forced_dtype_mapping(state_dict)
if not forced:
return fqn_to_dtype_mapping

normalized = dict(fqn_to_dtype_mapping)
state_dict_key_set = set(state_dict)
for fqn, dtype_str in forced.items():
if fqn in state_dict_key_set:
normalized[fqn] = dtype_str
return normalized


def is_cloud_path(path: str) -> bool:
"""Check if path is a cloud storage path (MSC)."""
return path.startswith("msc://")
Expand Down Expand Up @@ -1351,6 +1374,7 @@ def _maybe_build_original_dtype_mapping(
normalized_dtype_mapping = _normalize_dtype_mapping_to_state_dict_keys(
dtype_mapping, list(state_dict.keys()), getattr(model, "base_model_prefix", None)
)
normalized_dtype_mapping = _apply_adapter_forced_dtype_mapping(model, state_dict, normalized_dtype_mapping)
return normalized_dtype_mapping or None

def _get_storage_writer(
Expand Down
9 changes: 6 additions & 3 deletions nemo_automodel/components/distributed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,12 @@ class FSDP2Config:
activation checkpointing.
defer_fsdp_grad_sync (bool): Defer FSDP gradient sync to final micro-batch.
reshard_after_forward (Optional[bool]): Override layer-level FSDP2 resharding.
If ``None`` (default), AutoModel reshards all but the last layer outside
pipeline parallelism. Set ``False`` for a ZeRO-2-like benchmark where
gathered parameters stay resident after forward.
``None`` preserves AutoModel's heuristic: pipeline-parallel layers do
not reshard after forward, while non-pipeline layers reshard all but
the last layer. Set ``False`` for a ZeRO-2-like benchmark where
gathered parameters stay resident after forward. Set ``True`` to force
resharding everywhere, including pipeline-parallel layers, which may
reduce throughput by adding per-microbatch all-gathers.
enable_async_tensor_parallel (bool): Enable async tensor parallelism via
``torch._inductor.config._micro_pipeline_tp``. Overlaps ReduceScatter with
compute in row-parallel layers. Requires ``sequence_parallel=True`` (forced
Expand Down
12 changes: 6 additions & 6 deletions nemo_automodel/components/distributed/mamba_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,17 +461,17 @@ def get_conv1d_bias(self) -> torch.Tensor:
return None
return self._slice_conv_param(self._mixer.conv1d.bias)

def get_dt_bias(self) -> torch.Tensor:
def get_dt_bias(self, dt_bias: torch.Tensor | None = None) -> torch.Tensor:
"""Slice ``dt_bias`` for the current CP rank."""
return self._slice_vector_param(self._mixer.dt_bias)
return self._slice_vector_param(self._mixer.dt_bias if dt_bias is None else dt_bias)

def get_A_log(self) -> torch.Tensor:
def get_A_log(self, A_log: torch.Tensor | None = None) -> torch.Tensor:
"""Slice ``A_log`` for the current CP rank."""
return self._slice_vector_param(self._mixer.A_log)
return self._slice_vector_param(self._mixer.A_log if A_log is None else A_log)

def get_D(self) -> torch.Tensor:
def get_D(self, D: torch.Tensor | None = None) -> torch.Tensor:
"""Slice ``D`` for the current CP rank."""
return self._slice_vector_param(self._mixer.D)
return self._slice_vector_param(self._mixer.D if D is None else D)

# ------------------------------------------------------------------ #
# Internal helpers #
Expand Down
78 changes: 65 additions & 13 deletions nemo_automodel/components/distributed/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def parallelize(
dp_replicate_mesh_name: str = "dp_replicate",
dp_shard_cp_mesh_name: str = "dp_shard_cp",
tp_mesh_name: str = "tp",
reshard_after_forward: Optional[bool] = None,
**kwargs,
) -> nn.Module:
"""Apply parallelization strategy to the model."""
Expand Down Expand Up @@ -249,6 +250,12 @@ def parallelize(
# Set FSDP sharding mesh to context parallel mesh if CP > 1, else default to the data parallel mesh.
# if dp_replicate_size > 1, use HSDP, else use FSDP
dp_mesh = get_fsdp_dp_mesh(device_mesh, dp_replicate_mesh_name, dp_shard_cp_mesh_name)
pp_enabled = "pp" in dp_mesh.mesh_dim_names and dp_mesh["pp"].size() > 1
if pp_enabled and reshard_after_forward is True:
logger.warning(
"reshard_after_forward=True overrides the pipeline-parallel default of keeping layer weights "
"gathered across microbatches. This may increase per-microbatch all-gathers and reduce throughput."
)

# Extract layers from the model for parallelization
layers = _extract_model_layers(model)
Expand Down Expand Up @@ -413,6 +420,7 @@ def parallelize(
dp_replicate_mesh_name: str = "dp_replicate",
dp_shard_cp_mesh_name: str = "dp_shard_cp",
tp_mesh_name: str = "tp",
reshard_after_forward: Optional[bool] = None,
**kwargs,
) -> nn.Module:
"""Apply NemotronH-specific parallelization."""
Expand Down Expand Up @@ -490,6 +498,7 @@ def parallelize(
mp_policy=mp_policy,
offload_policy=offload_policy,
fp32_compute_module_names=fp32_compute_module_names,
reshard_after_forward=reshard_after_forward,
)

# do not reshard after forward for root model
Expand Down Expand Up @@ -527,23 +536,66 @@ def parallelize(self, model, device_mesh, dp_shard_cp_mesh_name="dp_shard_cp", *
original_fn = globals().get("apply_fsdp2_sharding_recursively")
assert original_fn is not None, "apply_fsdp2_sharding_recursively not found in module globals"

def _fsdp_by_dtype(module, mesh, mp_policy, offload_policy=None, *args, **kwargs):
def _fsdp_by_dtype(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM — clean PR. The fp32 holder pattern, reshard threading, and state-dict adapter routing all look correct. Good test coverage across models and distributed paths.

module,
mesh,
mp_policy,
offload_policy=None,
enable_fsdp2_prefetch=True,
fsdp2_backward_prefetch_depth=2,
fsdp2_forward_prefetch_depth=1,
reshard_after_forward=None,
fully_shard_fn=None,
):
del enable_fsdp2_prefetch, fsdp2_backward_prefetch_depth, fsdp2_forward_prefetch_depth, fully_shard_fn
pp_enabled = "pp" in mesh.mesh_dim_names and mesh["pp"].size() > 1

if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
items = module.items() if isinstance(module, nn.ModuleDict) else enumerate(module)
for layer_id, child in items:
if isinstance(child, (nn.ModuleList, nn.ModuleDict)):
_fsdp_by_dtype(child, mesh, mp_policy, offload_policy)
all_items = list(module.items()) if isinstance(module, nn.ModuleDict) else list(enumerate(module))
flat_layer_items = [
(layer_id, child)
for layer_id, child in all_items
if not isinstance(child, (nn.ModuleList, nn.ModuleDict))
]
nested_items = [
(layer_id, child)
for layer_id, child in all_items
if isinstance(child, (nn.ModuleList, nn.ModuleDict))
]

for _, child in nested_items:
_fsdp_by_dtype(
child,
mesh,
mp_policy,
offload_policy,
reshard_after_forward=reshard_after_forward,
)

for enum_id, (_, child) in enumerate(flat_layer_items):
if reshard_after_forward is not None:
layer_reshard_after_forward = reshard_after_forward
elif pp_enabled:
layer_reshard_after_forward = False
else:
parallelizer_utils.fully_shard_by_dtype(
child,
mesh,
mp_policy,
offload_policy,
fp32_compute_module_names=fp32_compute_module_names,
)
layer_reshard_after_forward = enum_id < len(flat_layer_items) - 1
parallelizer_utils.fully_shard_by_dtype(
child,
mesh,
mp_policy,
offload_policy,
fp32_compute_module_names=fp32_compute_module_names,
reshard_after_forward=layer_reshard_after_forward,
)
else:
for _, sub in module.named_children():
_fsdp_by_dtype(sub, mesh, mp_policy, offload_policy)
_fsdp_by_dtype(
sub,
mesh,
mp_policy,
offload_policy,
reshard_after_forward=reshard_after_forward,
)

globals()["apply_fsdp2_sharding_recursively"] = _fsdp_by_dtype
try:
Expand Down
18 changes: 16 additions & 2 deletions nemo_automodel/components/distributed/parallelizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,20 @@ def _fully_shard(
mesh: DeviceMesh,
mp_policy: Optional[MixedPrecisionPolicy],
offload_policy: Optional[OffloadPolicy],
reshard_after_forward: Optional[bool] = None,
) -> None:
if isinstance(module, nn.ModuleList):
for layer in module:
_fully_shard(layer, mesh, mp_policy, offload_policy)
_fully_shard(layer, mesh, mp_policy, offload_policy, reshard_after_forward)
else:
fully_shard(module, mesh=mesh, mp_policy=mp_policy, offload_policy=offload_policy)
kwargs = {
"mesh": mesh,
"mp_policy": mp_policy,
"offload_policy": offload_policy,
}
if reshard_after_forward is not None:
kwargs["reshard_after_forward"] = reshard_after_forward
fully_shard(module, **kwargs)


def _mp_policy_with_param_dtype(
Expand Down Expand Up @@ -206,6 +214,7 @@ def fully_shard_by_dtype(
mp_policy: Optional[MixedPrecisionPolicy],
offload_policy: Optional[OffloadPolicy],
fp32_compute_module_names: Tuple[str, ...] = (),
reshard_after_forward: Optional[bool] = None,
) -> None:
"""Fully shard a module so every parameter computes in its required dtype.

Expand All @@ -230,6 +239,8 @@ def fully_shard_by_dtype(
fp32_compute_module_names: Parameter/buffer name substrings that must compute in
fp32 (e.g. ``("_fp32_params",)`` for Qwen3.5's GatedDeltaNet fp32 holder).
Sourced from the model's ``_keep_in_fp32_modules_strict``.
reshard_after_forward: Optional FSDP2 reshard override for this module.
``None`` leaves the caller's default FSDP2 behavior unchanged.
"""
compute_dtype_of = _make_compute_dtype_fn(module, mp_policy, fp32_compute_module_names)

Expand All @@ -254,6 +265,7 @@ def fully_shard_by_dtype(
mesh=mesh,
mp_policy=_mp_policy_with_param_dtype(mp_policy, key[1]),
offload_policy=offload_policy,
reshard_after_forward=reshard_after_forward,
)
else:
least_items_key = min(grouped_params.items(), key=lambda x: len(x[1]))[0]
Expand All @@ -269,6 +281,7 @@ def fully_shard_by_dtype(
mesh=mesh,
mp_policy=_mp_policy_with_param_dtype(mp_policy, key[1]),
offload_policy=offload_policy,
reshard_after_forward=reshard_after_forward,
)
if len(grouped_params) == 2:
parent_key = next(key for key in grouped_params if key != least_items_key)
Expand All @@ -277,4 +290,5 @@ def fully_shard_by_dtype(
mesh=mesh,
mp_policy=_mp_policy_with_param_dtype(mp_policy, parent_key[1]),
offload_policy=offload_policy,
reshard_after_forward=reshard_after_forward,
)
14 changes: 14 additions & 0 deletions nemo_automodel/components/models/common/gated_delta_net_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,17 @@ def upcast_gated_delta_net_fp32_state_tensor(
if callable(is_floating_point) and is_floating_point():
return tensor.to(dtype=torch.float32)
return tensor


def forced_gated_delta_net_fp32_dtype_mapping(
state_dict: dict[str, object], param_names: tuple[str, ...] = FP32_GDN_PARAM_NAMES
) -> dict[str, str]:
"""Return HF export dtype overrides for intrinsically-fp32 GDN tensors."""
forced: dict[str, str] = {}
for key, tensor in state_dict.items():
if not is_gated_delta_net_fp32_param_key(key, param_names):
continue
is_floating_point = getattr(tensor, "is_floating_point", None)
if callable(is_floating_point) and is_floating_point():
forced[key] = "F32"
return forced
9 changes: 8 additions & 1 deletion nemo_automodel/components/models/deepseek_v4/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor

from nemo_automodel.components.models.common import (
BackendConfig,
Expand All @@ -80,6 +81,12 @@
)


def _full_tensor_if_dtensor(tensor: torch.Tensor) -> torch.Tensor:
if isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
return tensor.clone()


def _dsv4_kernel_backend(backend: BackendConfig) -> str:
"""Use TileLang DSV4 sparse kernels only when the attention backend requests them."""
return "tilelang" if backend.attn == "tilelang" else "torch"
Expand Down Expand Up @@ -623,7 +630,7 @@ def __init__(self, value: torch.Tensor):

def forward(self, reference: torch.Tensor | None = None) -> torch.Tensor:
del reference
return self.weight
return _full_tensor_if_dtensor(self.weight)


class DeepseekV4Indexer(nn.Module):
Expand Down
2 changes: 2 additions & 0 deletions nemo_automodel/components/models/hy_mt2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ class HyMT2ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
``from_pretrained`` / ``save_pretrained`` over the HF safetensors layout.
"""

_keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias"]

@dataclass(frozen=True)
class ModelCapabilities:
"""Declared parallelism capabilities for this model class."""
Expand Down
2 changes: 2 additions & 0 deletions nemo_automodel/components/models/hy_v3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:


class HYV3ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
_keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias"]

@dataclass(frozen=True)
class ModelCapabilities:
"""Declared parallelism capabilities for this model class."""
Expand Down
2 changes: 2 additions & 0 deletions nemo_automodel/components/models/minimax_m2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:


class MiniMaxM2ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
_keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias"]

@dataclass(frozen=True)
class ModelCapabilities:
"""Declared parallelism capabilities for this model class."""
Expand Down
3 changes: 3 additions & 0 deletions nemo_automodel/components/models/minimax_m3_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def mtp_logits(
class MiniMaxM3SparseForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
"""Standalone M3 text backbone for causal LM (Stage 1 parity target)."""

_keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias"]

# The state-dict adapter loads every tensor from the checkpoint, so skip HF
# random init on load (also avoids DTensor-collective hangs under sharding/PP).
_skip_init_weights_on_load = True
Expand Down Expand Up @@ -368,6 +370,7 @@ class MiniMaxM3SparseForConditionalGeneration(HFCheckpointingMixin, nn.Module, M
# (vision_encoder.py) fp32 — the bf16 cast would otherwise round it and degrade
# vision RoPE (see llama/rope_utils.py).
_keep_in_fp32_modules = ["rotary_emb", "inv_freq"]
_keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias"]
_pp_keep_self_forward: bool = True
mtp_outputs_are_logits = True
# The state-dict adapter fully populates every tensor from the checkpoint
Expand Down
Loading
Loading