Audit model-specific fp32 protected tensors during dtype casting #2570

Description

@yuhezhang-ai

Problem

Fixes on main and #2484 already address most of the known problems in the shared dtype-casting path, along with several concrete fp32-protection bugs.

PR #2484 covers the Qwen GatedDeltaNet path by making A_log / dt_bias explicit fp32 tensors. It also covers the currently known opt/router correction-bias case for Nemotron 3 Super.

The remaining work is an audit task: some other model families may still have numerically sensitive parameters or buffers that should be explicitly marked to stay fp32 during model dtype casting.
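To make "numerically sensitive" concrete, here is a minimal sketch (plain PyTorch, not code from this repository) of one failure mode: a tensor that receives small additive updates can stop changing entirely once it is stored in bf16. The step size and starting value are illustrative assumptions, not values taken from any specific model.

```python
import torch

# Illustrative assumption: a scalar near 1.0 receiving small additive updates.
step = 1e-3
acc_fp32 = torch.ones((), dtype=torch.float32)
acc_bf16 = torch.ones((), dtype=torch.bfloat16)

for _ in range(100):
    acc_fp32 += step
    # bf16 spacing near 1.0 is 2**-7 (about 0.0078), so 1.0 + 0.001
    # rounds back to 1.0 on every iteration and the update is lost.
    acc_bf16 += step

print(acc_fp32.item(), acc_bf16.item())
```

The fp32 copy ends near 1.1 while the bf16 copy never leaves 1.0, which is the kind of silent drift an explicit fp32 marker is meant to prevent.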

Scope

  • Audit model-specific parameters and buffers that should remain fp32 under bf16/fp16 model casts.
  • For each real case found, add the appropriate fp32 tracking marker, such as _keep_in_fp32_modules or _keep_in_fp32_modules_strict.
  • Add focused unit coverage showing the sensitive tensor remains fp32 after cast_model_to_dtype(...).
  • Check both normal construction and relevant sharded/FSDP2 behavior when applicable.
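A rough sketch of what the focused unit coverage could look like. Everything here is hypothetical: `ToyGatedBlock` is a made-up module, `toy_cast_model_to_dtype` is a stand-in for the project's `cast_model_to_dtype(...)` (whose real signature and marker handling are not shown in this issue), and treating `_keep_in_fp32_modules` as a list of leaf tensor names is an assumption made for illustration only.

```python
import torch
from torch import nn


class ToyGatedBlock(nn.Module):
    # Hypothetical marker semantics: leaf names of tensors that must stay fp32.
    _keep_in_fp32_modules = ["A_log", "dt_bias"]

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.A_log = nn.Parameter(torch.zeros(dim))
        self.dt_bias = nn.Parameter(torch.ones(dim))
        # Deliberately NOT listed in the marker, to show what an audit would flag.
        self.register_buffer("correction_bias", torch.zeros(dim))


def named_tensors(model: nn.Module):
    """All parameters and buffers with their qualified names."""
    return list(model.named_parameters()) + list(model.named_buffers())


def toy_cast_model_to_dtype(model: nn.Module, dtype: torch.dtype) -> nn.Module:
    """Stand-in for the real cast helper; honors the hypothetical marker above."""
    keep = set()
    for module in model.modules():
        keep.update(getattr(module, "_keep_in_fp32_modules", ()))
    for name, tensor in named_tensors(model):
        if not tensor.is_floating_point():
            continue
        leaf = name.rsplit(".", 1)[-1]
        tensor.data = tensor.data.to(torch.float32 if leaf in keep else dtype)
    return model


model = toy_cast_model_to_dtype(ToyGatedBlock(), torch.bfloat16)

# Protected tensors stay fp32; ordinary weights follow the requested dtype.
assert model.A_log.dtype == torch.float32
assert model.dt_bias.dtype == torch.float32
assert model.proj.weight.dtype == torch.bfloat16

# Audit view: every tensor that was downcast. An unmarked sensitive buffer
# such as correction_bias shows up here and would need a marker added.
downcast = sorted(n for n, t in named_tensors(model) if t.dtype != torch.float32)
print(downcast)
```

The same assertions could be repeated on a sharded model where applicable. That part is omitted here because it depends on the project's FSDP2 setup.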

Notes

This issue tracks only the model-by-model audit work that remains after the main dtype fixes and #2484.
