
fix(gemma4_moe): re-tie lm_head to active embed_tokens on MoE path #2601

Open
Achyuthan-S wants to merge 4 commits into NVIDIA-NeMo:main from Achyuthan-S:Achyuthan-S/fix/gemma4-moe-tied-lm-head

Conversation

@Achyuthan-S

@Achyuthan-S Achyuthan-S commented Jun 16, 2026

Picking up the Gemma4 MoE part of #2512.

Problem

Gemma4ForConditionalGeneration lets HF's super().__init__() set up the model, which ties lm_head.weight to the text embed_tokens. The MoE path then replaces self.model.language_model with a fresh Gemma4MoETextModelBackend that has its own embed_tokens, so lm_head is left aliased to the old, orphaned embedding instead of the active one.

Gemma defaults to tie_word_embeddings=True, so this runs with an effectively untied head: lm_head and the live embedding drift apart during training, and the tied-head checkpoint guard from #2511 can't tell they're supposed to be tied (the two tensors don't share storage).
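
The orphaned alias described above can be reproduced without any framework. The following is a minimal sketch using plain Python objects as stand-ins for parameters, where object identity models shared storage. The names `Param`, `TextBackbone`, and `ToyModel` are hypothetical and are not part of NeMo or Hugging Face.

```python
class Param:
    """Stand-in for a weight tensor; object identity models shared storage."""

    def __init__(self, label: str) -> None:
        self.label = label


class TextBackbone:
    """Stand-in for a text model that owns its own embed_tokens."""

    def __init__(self, label: str) -> None:
        self.embed_tokens = Param(label)


class ToyModel:
    def __init__(self) -> None:
        self.language_model = TextBackbone("hf_embed")
        # The tie is set up once, at construction time.
        self.lm_head = self.language_model.embed_tokens

    def swap_backbone(self) -> None:
        # Replacing the backbone creates a fresh embedding; lm_head is not updated.
        self.language_model = TextBackbone("moe_embed")

    def head_is_tied(self) -> bool:
        return self.lm_head is self.language_model.embed_tokens


model = ToyModel()
tied_before = model.head_is_tied()
model.swap_backbone()
tied_after = model.head_is_tied()
orphan_label = model.lm_head.label  # still points at the old embedding
```

After the swap, the head still references the first embedding, which is the state the PR describes as "aliased to the old, orphaned embedding".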

Fix

After the language_model swap, re-point lm_head.weight at the now-active embed_tokens.weight when tie_word_embeddings is set. The shared parameter survives the in-place bf16 cast in initialize_weights(). With the storage genuinely shared again, has_local_tied_lm_head() returns True, so the save path drops lm_head and the load path reconstructs it instead of letting a second copy diverge.

I left the state-dict adapter alone on purpose. Once the head is actually tied, loading a tied HF checkpoint (which only stores embed_tokens.weight) is already covered by the existing ensure_tied_lm_head() on the load path.
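
As a rough illustration of the save and load behavior described above, the sketch below drops the head entry when the model is tied and rebuilds it from the embedding on load. The helpers `save_state` and `load_state` are hypothetical; they are not the actual `has_local_tied_lm_head()` or `ensure_tied_lm_head()` implementations, and plain lists stand in for tensors.

```python
EMBED_KEY = "model.language_model.embed_tokens.weight"
HEAD_KEY = "lm_head.weight"


def save_state(state: dict, tied: bool) -> dict:
    """Hypothetical save path: omit the head when it aliases the embedding."""
    saved = dict(state)
    if tied:
        saved.pop(HEAD_KEY, None)
    return saved


def load_state(saved: dict, tied: bool) -> dict:
    """Hypothetical load path: rebuild the head as an alias of the embedding."""
    loaded = dict(saved)
    if tied and HEAD_KEY not in loaded:
        loaded[HEAD_KEY] = loaded[EMBED_KEY]
    return loaded


shared = [0.1, 0.2, 0.3]  # stand-in for the shared weight tensor
state = {EMBED_KEY: shared, HEAD_KEY: shared}

saved = save_state(state, tied=True)
loaded = load_state(saved, tied=True)
head_saved = HEAD_KEY in saved
head_aliases_embed = loaded[HEAD_KEY] is loaded[EMBED_KEY]
```

Because only one copy is written, there is no second tensor that could diverge between save and resume.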

Tests

Added tests/unit_tests/models/gemma4_moe/test_gemma4_moe_tied_weights.py (CPU, reuses the tiny-config helpers from the existing rope test):

  • tied config shares lm_head/embed_tokens storage right after construction
  • the tie survives initialize_weights(bf16)
  • untied config keeps separate storage
pytest tests/unit_tests/models/gemma4_moe/ -v
# 4 passed

A full save → load → resume checkpoint test would be a reasonable follow-up. The per-family reject-guard work for the rest of the audit is tracked separately under #2512.

cc @yuhezhang-ai

Copilot AI review requested due to automatic review settings June 16, 2026 19:10
@Achyuthan-S Achyuthan-S requested a review from a team as a code owner June 16, 2026 19:10
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds regression coverage and a small model fix to ensure Gemma4 MoE preserves Hugging Face-style weight tying between lm_head and the active embedding after the MoE backend swaps out the language model.

Changes:

  • Re-tie lm_head.weight to the MoE backend’s embed_tokens.weight when tie_word_embeddings=True.
  • Add CPU-only unit tests validating tied and untied behavior, including after initialize_weights() casting.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/unit_tests/models/gemma4_moe/test_gemma4_moe_tied_weights.py | New unit tests to lock in correct tied/untied lm_head ↔ embedding behavior across initialization. |
| nemo_automodel/components/models/gemma4_moe/model.py | Restores lm_head weight tying after replacing the HF language model with the MoE backend. |

Comments suppressed due to low confidence (1)

nemo_automodel/components/models/gemma4_moe/model.py:859

  • buffer_device is a torch.device, but the code uses with buffer_device:, which is invalid (devices are not context managers). This is currently a runtime error on the MoE path when initialize_weights() reaches this block. Replace the context-manager usage with a valid device context (typically torch.cuda.device(index) when CUDA is available) or restructure the code to avoid needing a context manager at all and instead move tensors/ops explicitly to the target device.
            else getattr(

Comment on lines +818 to +819
if getattr(text_config, "tie_word_embeddings", False):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight

Author

This is intentional. HF's tie_weights() already ran in super().__init__() and tied lm_head to the original embed_tokens; the issue is that the MoE swap replaces language_model afterward, so re-running tie_weights() leans on the same get_input_embeddings() indirection this is working around (moe/parallelizer.py notes HF's tie_weights() can be incompatible with these custom models). The direct re-point mirrors llama's tie_weights(). Downstream save/load still goes through ensure_tied_lm_head(), which tries tie_weights() first and falls back to direct assignment, so the conventional path is covered.

Contributor

I think the direct re-point is the right primitive here, but it would still be better to expose it through this class’s own tie_weights() override.

The reason is that AutoModel/checkpoint paths can call model.tie_weights() again after construction/load. We should make sure that call re-ties to the active MoE embedding, not whatever HF’s generic behavior would do. So the method can still use your exact direct assignment:

def tie_weights(self):
    text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
    if getattr(text_config, "tie_word_embeddings", False):
        self.lm_head.weight = self.model.language_model.embed_tokens.weight

Then after replacing language_model, call self.tie_weights().

This preserves the intentional direct assignment, but makes the public re-tie hook reliable for downstream load/checkpoint paths.

Author

Done — pulled the direct re-point into tie_weights() and call it after the language_model swap. Added tests that break the tie and assert tie_weights() re-points to the active MoE embedding (and is a no-op when untied).
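
The hook behavior tested here (re-tie after the tie is broken, no-op when untied) can be sketched with plain objects. This is an illustrative stand-in, not the PR's test file; `ToyHeadModel` and `break_tie` are hypothetical names, and `object()` instances stand in for weight tensors.

```python
class ToyHeadModel:
    def __init__(self, tie_word_embeddings: bool) -> None:
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_weight = object()    # stand-in for embed_tokens.weight
        self.lm_head_weight = object()  # stand-in for lm_head.weight
        self.tie_weights()

    def tie_weights(self) -> None:
        # Direct re-point to whichever embedding is currently active.
        if self.tie_word_embeddings:
            self.lm_head_weight = self.embed_weight


def break_tie(model: ToyHeadModel) -> None:
    """Simulate something replacing the head with an unrelated weight."""
    model.lm_head_weight = object()


tied = ToyHeadModel(tie_word_embeddings=True)
break_tie(tied)
broken = tied.lm_head_weight is not tied.embed_weight
tied.tie_weights()
retied = tied.lm_head_weight is tied.embed_weight

untied = ToyHeadModel(tie_word_embeddings=False)
head_before = untied.lm_head_weight
untied.tie_weights()
untouched = untied.lm_head_weight is head_before
still_separate = untied.lm_head_weight is not untied.embed_weight
```

Calling the hook repeatedly is harmless in this sketch, which is the property that matters if load or checkpoint paths call `tie_weights()` again after construction.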

def test_tied_lm_head_survives_initialize_weights():
    """The tie set in __init__ must survive the bf16 cast in initialize_weights()."""
    model = _build(tie_word_embeddings=True)
    model.initialize_weights(dtype=torch.bfloat16, buffer_device=torch.device("cpu"))

Author

torch.device has been usable as a context manager since PyTorch 2.0: with torch.device(...): sets the default device for factory calls, so with buffer_device: is valid, including on CPU. The test here passes buffer_device=torch.device("cpu") and runs through this exact MoE block; it's green on torch 2.10, so there's no TypeError on this path. initialize_weights() is also pre-existing and unchanged by this PR.

The MoE path replaces language_model after HF __init__, orphaning the lm_head<->embed_tokens tie that HF set up. Re-tie lm_head to the active embed_tokens when tie_word_embeddings is set (Gemma defaults True). Add CPU tied/untied tests.

Refs NVIDIA-NeMo#2512

Signed-off-by: Achyuthan-S <as21154@nyu.edu>
@Achyuthan-S Achyuthan-S force-pushed the Achyuthan-S/fix/gemma4-moe-tied-lm-head branch from 46282b9 to 4497260 Compare June 16, 2026 19:14
Achyuthan-S and others added 2 commits June 16, 2026 23:52
Wrap the post-swap lm_head re-point in a tie_weights() override so
AutoModel and ensure_tied_lm_head() re-tie to the active MoE embedding.
Add hook tests for re-tie and untied no-op.

Signed-off-by: Achyuthan Sivasankar <achyuthan.sivasankar@gmail.com>
@athitten

Contributor

/ok to test b5c96b0

@Achyuthan-S

Author

Hi @yuhezhang-ai, thank you very much for your support. Please feel free to assign me to interesting issues or features; I would love to contribute more!

@yuhezhang-ai

Contributor

Thank you, and thanks again for the quick fix here. The PR looks good to me.
If you are interested, please feel free to continue with the broader tied-embedding audit we discussed earlier.

@Achyuthan-S

Author

Thanks @yuhezhang-ai — glad the Gemma4 piece looks good.

I'll pick up the broader audit next: confirm HF defaults / checkpoint behavior per family, then land the reject-when-tie=True guard for the separate-head models (as we scoped on #2512). I'll post the per-family findings on the issue before opening a broad PR.

@athitten

Contributor

/claude review

@athitten

Contributor

/ok to test b5c96b0

dtype=get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16),
)

def tie_weights(self, *_args: object, **_kwargs: object) -> None:

Contributor

Nit: every other model that implements tie_weights() also declares _tied_weights_keys (e.g. Llama, Qwen2, Mistral3, Baichuan all set _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}). The checkpoint utilities have fallbacks that cover this (including the hardcoded "model.language_model.embed_tokens.weight" candidate in _candidate_source_names()), so this isn't a correctness bug today — but adding the declaration would be consistent with the rest of the codebase and the onboarding checklist:

_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

(as a class attribute on Gemma4ForConditionalGeneration). Pre-existing gap, so fine to address in a follow-up.

Author

_tied_weights_keys is already inherited from HFGemma4ForConditionalGeneration as {'lm_head.weight': 'model.language_model.embed_tokens.weight'} — the mapping you suggested. Checkpoint utils read it via the dict branch in get_tied_lm_head_source_names, so the tied source resolves through the declaration rather than a fallback.

Llama/Qwen2 declare it explicitly because they're NeMo-native classes with no HF parent. Gemma4 subclasses the HF class that already sets it, so a re-declaration would be redundant. Happy to add an explicit class attribute for visibility if a maintainer prefers, but it's not needed for correctness.
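
The inheritance argument can be checked with a small sketch: a class attribute declared on a parent is visible on a subclass that never re-declares it. `HFParent`, `NeMoChild`, and `tied_source_for` are hypothetical stand-ins; the real HF class and `get_tied_lm_head_source_names` are not reproduced here.

```python
from typing import Optional


class HFParent:
    """Stand-in for the HF base class that declares the tied-weights mapping."""

    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}


class NeMoChild(HFParent):
    """Stand-in for a subclass that does not re-declare the mapping."""


def tied_source_for(model_cls: type, target: str = "lm_head.weight") -> Optional[str]:
    """Hypothetical lookup: resolve the tied source through the dict declaration."""
    keys = getattr(model_cls, "_tied_weights_keys", None)
    if isinstance(keys, dict):
        return keys.get(target)
    return None


declared_on_child = "_tied_weights_keys" in vars(NeMoChild)
source = tied_source_for(NeMoChild)
```

In this sketch the lookup succeeds on the subclass even though the attribute only exists in the parent's namespace, so an explicit re-declaration would add visibility but not behavior.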

"""
text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
if getattr(text_config, "tie_word_embeddings", False):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight

Contributor

One small correction after our discussion in the issue and checking Gemma4 HF behavior directly: for Gemma4ForConditionalGeneration, the controlling flag is the top-level Gemma4Config.tie_word_embeddings, not text_config.tie_word_embeddings.

I tested the conflicting cases under transformers 5.8.1:

  • top-level True, text True -> tied
  • top-level False, text False -> untied
  • top-level False, text True -> untied
  • top-level True, text False -> tied

So the current Gemma4 MoE fix is correct for the normal/default case because both flags are True, but the override should not check text_config first. Otherwise it can diverge from HF behavior when the two flags disagree.

I'd suggest changing the tie check to top-level-first, e.g.:

text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
if getattr(self.config, "tie_word_embeddings", getattr(text_config, "tie_word_embeddings", False)):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight

And in the tests, set the top-level flag explicitly too:

config = Gemma4Config(
    text_config=_make_text_config(tie_word_embeddings=tie_word_embeddings),
    tie_word_embeddings=tie_word_embeddings,
)

That keeps the fix aligned with HF's actual tying behavior instead of baking in the text_config-first rule.
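
The top-level-first rule can be exercised against all four flag combinations using plain namespaces in place of real config objects. This is a sketch only: `should_tie` and `make_config` are hypothetical helpers that mirror the suggested `getattr` expression, and `SimpleNamespace` stands in for `Gemma4Config` and its text config.

```python
from types import SimpleNamespace


def should_tie(config) -> bool:
    """Top-level flag wins; text_config is only a fallback when it is absent."""
    text_config = config.text_config if hasattr(config, "text_config") else config
    return bool(
        getattr(
            config,
            "tie_word_embeddings",
            getattr(text_config, "tie_word_embeddings", False),
        )
    )


def make_config(top: bool, text: bool) -> SimpleNamespace:
    return SimpleNamespace(
        tie_word_embeddings=top,
        text_config=SimpleNamespace(tie_word_embeddings=text),
    )


# Keys are (top-level flag, text_config flag).
results = {
    (top, text): should_tie(make_config(top, text))
    for top in (True, False)
    for text in (True, False)
}

# Fallback case: no top-level attribute at all, so text_config decides.
fallback = should_tie(SimpleNamespace(text_config=SimpleNamespace(tie_word_embeddings=True)))
```

The resulting mapping follows the top-level flag in every row, matching the four cases listed in the comment, while the fallback only applies when the top-level attribute is missing.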

@Achyuthan-S Achyuthan-S Jun 17, 2026

Author

Oh, I see, good catch. For the default case both flags are True, so the re-tie was correct, but you're right that it should follow the top-level Gemma4Config.tie_word_embeddings. Confirmed by construction on transformers 5.8.1 that all four combinations match your table. Switched the override to top-level-first (with a text_config fallback), set the top-level flag explicitly in the tests, and added a case with the flags disagreeing (top True/text False → tied, top False/text True → untied) to pin it. pytest tests/unit_tests/models/gemma4_moe/ -v → 7 passed.

Per review: HF Gemma4 ties on the top-level Gemma4Config.tie_word_embeddings regardless of the nested text_config (verified by construction on transformers 5.8.1). Read top-level first with a text_config fallback, set the top-level flag in tests, and add a case proving top-level wins when the two disagree.
Refs NVIDIA-NeMo#2512

Signed-off-by: Achyuthan Sivasankar <achyuthan.sivasankar@gmail.com>
@yuhezhang-ai

Contributor

/ok to test bed4c65

@akoumpa

akoumpa commented Jun 18, 2026

Contributor

/claude review

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label Jun 20, 2026
@Achyuthan-S

Author

Thanks again for the review and approval, @yuhezhang-ai! Is anything else needed from my side before it can be merged, or is it just waiting on a merge window? Happy to rebase if that helps.

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-maintainers Waiting on maintainers to respond label Jun 20, 2026