fix(gemma4_moe): re-tie lm_head to active embed_tokens on MoE path #2601
Changes from all commits
4497260
aa6c324
b5c96b0
bed4c65
@@ -868,6 +868,15 @@ def __init__(

            # Expose moe_config for the MoE parallelizer assertion
            self.model.moe_config = self.model.language_model.moe_config

            # HF's super().__init__() tied lm_head.weight to the *original* text
            # embed_tokens, but the language_model replacement above swapped in a
            # fresh embed_tokens and orphaned that alias. Re-tie through our own
            # tie_weights() override so the public hook (also invoked by AutoModel
            # and checkpoint load via ensure_tied_lm_head) re-points to the active
            # MoE embedding. The shared Parameter survives the in-place cast in
            # initialize_weights().
            self.tie_weights()

            self.vocab_size = text_config.vocab_size
            # State dict adapter for HF ↔ NeMo weight conversion
            if self.backend.enable_hf_state_dict_adapter:

@@ -878,6 +887,29 @@ def __init__(

                    dtype=get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16),
                )

        def tie_weights(self, *_args: object, **_kwargs: object) -> None:
            """Tie ``lm_head`` to the active text ``embed_tokens`` when requested.

            Overrides HF's generic tying so that any caller after the MoE
            ``language_model`` swap (construction, AutoModel, and checkpoint load
            via ``ensure_tied_lm_head``) re-points ``lm_head`` to the *active*
            embedding rather than whatever HF's ``get_input_embeddings()``
            indirection resolves to. No-op when the config requests untied
            embeddings.

            Accepts and ignores positional/keyword arguments (e.g. HF v5's
            ``recompute_mapping``) so it stays drop-in compatible with the HF
            ``init_weights() -> tie_weights(...)`` call path.

            The controlling flag is the top-level ``Gemma4Config.tie_word_embeddings``
            (verified against HF: the top-level flag decides tying regardless of the
            nested ``text_config`` value), so read it first and only fall back to
            ``text_config`` for configs that don't expose a top-level flag.
            """
            text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
            if getattr(self.config, "tie_word_embeddings", getattr(text_config, "tie_word_embeddings", False)):
                self.lm_head.weight = self.model.language_model.embed_tokens.weight

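The orphaned alias described in the comments above can be reproduced without PyTorch. The following is a minimal sketch, not the PR's code: `Param`, `TextModel`, and `ToyModel` are hypothetical stand-ins for `torch.nn.Parameter`, the text backbone, and the conditional-generation model, and `cast_` only imitates an in-place dtype cast. It shows that swapping the backbone leaves `lm_head` pointing at the old object, that a `tie_weights()` call which ignores extra arguments re-points it, and that an in-place change afterwards is visible through both names.

```python
class Param:
    """Hypothetical stand-in for torch.nn.Parameter: a mutable box of values."""

    def __init__(self, values):
        self.values = list(values)
        self.dtype = "float32"

    def cast_(self, dtype):
        # In-place cast: mutates this object, so every alias sees the change.
        self.dtype = dtype
        return self


class TextModel:
    """Hypothetical stand-in for a text backbone that owns embed_tokens."""

    def __init__(self, values):
        self.embed_tokens = Param(values)


class ToyModel:
    def __init__(self, tie_word_embeddings=True):
        self.tie_word_embeddings = tie_word_embeddings
        # Step 1: parent-style construction ties lm_head to the original embedding.
        self.language_model = TextModel([0.0, 0.0])
        self.lm_head = self.language_model.embed_tokens
        # Step 2: the backbone is swapped; lm_head still points at the old object.
        self.language_model = TextModel([1.0, 2.0])

    def tie_weights(self, *_args, **_kwargs):
        # Extra arguments are accepted and ignored, mirroring the override above.
        if self.tie_word_embeddings:
            self.lm_head = self.language_model.embed_tokens


model = ToyModel()
orphaned = model.lm_head is not model.language_model.embed_tokens

# The keyword below is only an example of an argument the hook ignores.
model.tie_weights(recompute_mapping=False)
retied = model.lm_head is model.language_model.embed_tokens

# An in-place cast keeps the alias; rebinding the attribute would not.
model.language_model.embed_tokens.cast_("bfloat16")
still_tied = (
    model.lm_head is model.language_model.embed_tokens
    and model.lm_head.dtype == "bfloat16"
)
```

The design point is that tying is an identity relationship between two attribute names, so any step that rebinds one of them has to be followed by a re-tie, while steps that mutate the shared object in place do not.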
Contributor

One small correction after our discussion in the issue and checking Gemma4 HF behavior directly. I tested the conflicting cases under transformers 5.8.1:

So the current Gemma4 MoE fix is correct for the normal/default case because both flags are `True`. I'd suggest changing the tie check to top-level-first, e.g.:

    text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
    if getattr(self.config, "tie_word_embeddings", getattr(text_config, "tie_word_embeddings", False)):
        self.lm_head.weight = self.model.language_model.embed_tokens.weight

And in the tests, set the top-level flag explicitly too:

    config = Gemma4Config(
        text_config=_make_text_config(tie_word_embeddings=tie_word_embeddings),
        tie_word_embeddings=tie_word_embeddings,
    )

That keeps the fix aligned with HF's actual tying behavior instead of baking in the
Contributor
Author

Oh I see, good catch. For the default case both flags are True, so the re-tie was correct, but you're right, it should follow the top-level
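The top-level-first lookup discussed in this thread can be checked in isolation. This is a small sketch under stated assumptions: `should_tie` is a hypothetical helper name, and `types.SimpleNamespace` objects stand in for `Gemma4Config` and its nested `text_config`; only the nested `getattr` expression is taken from the suggested change.

```python
from types import SimpleNamespace


def should_tie(config) -> bool:
    """Return the tying decision, reading the top-level flag first."""
    text_config = config.text_config if hasattr(config, "text_config") else config
    # The inner getattr is only used when the top-level attribute is absent.
    return bool(
        getattr(
            config,
            "tie_word_embeddings",
            getattr(text_config, "tie_word_embeddings", False),
        )
    )


# Conflicting flags: the top-level value decides in both directions.
top_true = SimpleNamespace(
    tie_word_embeddings=True,
    text_config=SimpleNamespace(tie_word_embeddings=False),
)
top_false = SimpleNamespace(
    tie_word_embeddings=False,
    text_config=SimpleNamespace(tie_word_embeddings=True),
)
# No top-level flag at all: fall back to the nested text_config value.
nested_only = SimpleNamespace(text_config=SimpleNamespace(tie_word_embeddings=True))
# Neither level exposes the flag: default to untied.
no_flag = SimpleNamespace(text_config=SimpleNamespace())

results = [should_tie(c) for c in (top_true, top_false, nested_only, no_flag)]
```

Note that the fallback is keyed on the attribute being missing, not on its value being falsy, so an explicit top-level `False` is never overridden by the nested config.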

        def forward(
            self,
            input_ids: torch.Tensor | None = None,

@@ -0,0 +1,144 @@

    # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    """Weight-tying tests for the Gemma4 MoE conditional-generation model.

    HF's ``super().__init__()`` ties ``lm_head.weight`` to the original text
    ``embed_tokens``. The MoE path then replaces ``language_model`` with
    ``Gemma4MoETextModelBackend`` (a fresh ``embed_tokens``), which orphans that
    alias. The model re-ties ``lm_head`` to the now-active embedding when
    ``tie_word_embeddings`` is set (Gemma defaults to ``True``); these tests pin
    that behavior for both the tied and untied configs.

    Runs on CPU (no CUDA / TE / DeepEP required).
    """

    import torch
    import torch.nn as nn
    from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig

    from nemo_automodel.components.models.common import BackendConfig
    from nemo_automodel.components.models.gemma4_moe.model import (
        Gemma4ForConditionalGeneration,
        Gemma4MoETextModelBackend,
    )


    def _make_text_config(**overrides):
        """Tiny Gemma4TextConfig (2 layers, small hidden, tiny vocab, few experts)."""
        defaults = dict(
            vocab_size=256,
            hidden_size=64,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=16,
            num_hidden_layers=2,
            intermediate_size=128,
            rms_norm_eps=1e-6,
            max_position_embeddings=256,
            enable_moe_block=True,  # routes construction through the NeMo MoE backend
            num_experts=4,
            top_k_experts=2,
            moe_intermediate_size=64,
            layer_types=["full_attention", "sliding_attention"],
            sliding_window=128,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="bfloat16",
        )
        defaults.update(overrides)
        return Gemma4TextConfig(**defaults)


    def _make_cpu_backend():
        """CPU-friendly backend: no TE, no DeepEP, plain torch kernels."""
        return BackendConfig(
            linear="torch",
            attn="sdpa",
            rms_norm="torch",
            experts="torch",
            dispatcher="torch",
            fake_balanced_gate=False,
            enable_hf_state_dict_adapter=False,
        )


    def _build(tie_word_embeddings: bool, text_tie: bool | None = None) -> Gemma4ForConditionalGeneration:
        # The controlling flag is the top-level Gemma4Config.tie_word_embeddings (matches HF);
        # text_tie lets a test set a conflicting nested value to prove top-level wins.
        config = Gemma4Config(
            text_config=_make_text_config(tie_word_embeddings=tie_word_embeddings if text_tie is None else text_tie),
            tie_word_embeddings=tie_word_embeddings,
        )
        model = Gemma4ForConditionalGeneration(config, backend=_make_cpu_backend())
        # Sanity: construction routed through the real NeMo MoE backend (the path
        # that replaces language_model and breaks HF's tie).
        assert isinstance(model.model.language_model, Gemma4MoETextModelBackend)
        return model


    def test_tied_lm_head_shares_active_embedding_after_construction():
        """tie_word_embeddings=True: lm_head must alias the *active* MoE embed_tokens."""
        model = _build(tie_word_embeddings=True)
        assert model.lm_head.weight is model.model.language_model.embed_tokens.weight


    def test_tied_lm_head_survives_initialize_weights():
        """The tie set in __init__ must survive the bf16 cast in initialize_weights()."""
        model = _build(tie_word_embeddings=True)
        model.initialize_weights(dtype=torch.bfloat16, buffer_device=torch.device("cpu"))

        embed = model.model.language_model.embed_tokens.weight
        lm_head = model.lm_head.weight
        assert lm_head is embed
        assert lm_head.dtype == torch.bfloat16


    def test_untied_lm_head_is_separate():
        """tie_word_embeddings=False: lm_head must keep its own storage."""
        model = _build(tie_word_embeddings=False)
        assert model.lm_head.weight is not model.model.language_model.embed_tokens.weight
        assert model.lm_head.weight.data_ptr() != model.model.language_model.embed_tokens.weight.data_ptr()


    def test_tie_weights_hook_reties_to_active_embedding():
        """The public tie_weights() hook must re-point lm_head at the active MoE embedding.

        Checkpoint/AutoModel paths call ``model.tie_weights()`` after construction
        (e.g. via ``ensure_tied_lm_head``); this guards that the override re-ties to
        the live embedding rather than HF's generic behavior.
        """
        model = _build(tie_word_embeddings=True)
        # Break the tie with a fresh, independent parameter.
        model.lm_head.weight = nn.Parameter(model.lm_head.weight.detach().clone())
        assert model.lm_head.weight is not model.model.language_model.embed_tokens.weight

        model.tie_weights()
        assert model.lm_head.weight is model.model.language_model.embed_tokens.weight


    def test_tie_weights_hook_is_noop_when_untied():
        """tie_weights() must not alias storage when the config requests untied embeddings."""
        model = _build(tie_word_embeddings=False)
        model.tie_weights()
        assert model.lm_head.weight is not model.model.language_model.embed_tokens.weight


    def test_top_level_flag_controls_tie_when_flags_disagree():
        """The controlling flag is top-level Gemma4Config.tie_word_embeddings, not text_config (matches HF)."""
        # top-level True wins over text_config False -> tied
        tied = _build(tie_word_embeddings=True, text_tie=False)
        assert tied.lm_head.weight is tied.model.language_model.embed_tokens.weight
        # top-level False wins over text_config True -> untied
        untied = _build(tie_word_embeddings=False, text_tie=True)
        assert untied.lm_head.weight is not untied.model.language_model.embed_tokens.weight

Nit: every other model that implements `tie_weights()` also declares `_tied_weights_keys` (e.g. Llama, Qwen2, Mistral3, Baichuan all set `_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}`). The checkpoint utilities have fallbacks that cover this (including the hardcoded `"model.language_model.embed_tokens.weight"` candidate in `_candidate_source_names()`), so this isn't a correctness bug today, but adding the declaration would be consistent with the rest of the codebase and the onboarding checklist: (as a class attribute on `Gemma4ForConditionalGeneration`). Pre-existing gap, so fine to address in a follow-up.

`_tied_weights_keys` is already inherited from `HFGemma4ForConditionalGeneration` as `{'lm_head.weight': 'model.language_model.embed_tokens.weight'}`, the mapping you suggested. Checkpoint utils read it via the dict branch in `get_tied_lm_head_source_names`, so the tied source resolves through the declaration rather than a fallback.

Llama/Qwen2 declare it explicitly because they're NeMo-native classes with no HF parent. Gemma4 subclasses the HF class that already sets it, so a re-declaration would be redundant. Happy to add an explicit class attribute for visibility if a maintainer prefers, but it's not needed for correctness.
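The inheritance argument in the reply can be illustrated with plain Python class attributes. This is a sketch only: `HFParentModel`, `MoEChildModel`, and `tied_source_for` are hypothetical names, and the reader function is an assumed simplification that handles just the dict form of the declaration, not the actual checkpoint utility.

```python
class HFParentModel:
    """Hypothetical stand-in for the upstream class that declares the mapping."""

    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}


class MoEChildModel(HFParentModel):
    """Hypothetical subclass that overrides tie_weights() but not the mapping."""

    def tie_weights(self, *_args, **_kwargs):
        return None


def tied_source_for(model_cls, target="lm_head.weight"):
    # Assumed reader logic: look the attribute up through the normal MRO
    # and use it only when it is a dict.
    keys = getattr(model_cls, "_tied_weights_keys", None)
    if isinstance(keys, dict):
        return keys.get(target)
    return None


inherited = tied_source_for(MoEChildModel)
declared_on_child = "_tied_weights_keys" in vars(MoEChildModel)
```

Because attribute lookup walks the class hierarchy, a reader that uses `getattr` sees the parent's mapping on the subclass even though the subclass body never mentions it, which is why a re-declaration would only add visibility.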