Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions nemo_automodel/components/models/gemma4_moe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,15 @@ def __init__(
# Expose moe_config for the MoE parallelizer assertion
self.model.moe_config = self.model.language_model.moe_config

# HF's super().__init__() tied lm_head.weight to the *original* text
# embed_tokens, but the language_model replacement above swapped in a
# fresh embed_tokens and orphaned that alias. Re-tie through our own
# tie_weights() override so the public hook (also invoked by AutoModel
# and checkpoint load via ensure_tied_lm_head) re-points to the active
# MoE embedding. The shared Parameter survives the in-place cast in
# initialize_weights().
self.tie_weights()

self.vocab_size = text_config.vocab_size
# State dict adapter for HF ↔ NeMo weight conversion
if self.backend.enable_hf_state_dict_adapter:
Expand All @@ -878,6 +887,29 @@ def __init__(
dtype=get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16),
)

def tie_weights(self, *_args: object, **_kwargs: object) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: every other model that implements tie_weights() also declares _tied_weights_keys (e.g. Llama, Qwen2, Mistral3, Baichuan all set _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}). The checkpoint utilities have fallbacks that cover this (including the hardcoded "model.language_model.embed_tokens.weight" candidate in _candidate_source_names()), so this isn't a correctness bug today — but adding the declaration would be consistent with the rest of the codebase and the onboarding checklist:

_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

(as a class attribute on Gemma4ForConditionalGeneration). Pre-existing gap, so fine to address in a follow-up.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_tied_weights_keys is already inherited from HFGemma4ForConditionalGeneration as {'lm_head.weight': 'model.language_model.embed_tokens.weight'} — the mapping you suggested. Checkpoint utils read it via the dict branch in get_tied_lm_head_source_names, so the tied source resolves through the declaration rather than a fallback.

Llama/Qwen2 declare it explicitly because they're NeMo-native classes with no HF parent. Gemma4 subclasses the HF class that already sets it, so a re-declaration would be redundant. Happy to add an explicit class attribute for visibility if a maintainer prefers, but it's not needed for correctness.

"""Tie ``lm_head`` to the active text ``embed_tokens`` when requested.

Overrides HF's generic tying so that any caller after the MoE
``language_model`` swap (construction, AutoModel, and checkpoint load
via ``ensure_tied_lm_head``) re-points ``lm_head`` to the *active*
embedding rather than whatever HF's ``get_input_embeddings()``
indirection resolves to. No-op when the config requests untied
embeddings.

Accepts and ignores positional/keyword arguments (e.g. HF v5's
``recompute_mapping``) so it stays drop-in compatible with the HF
``init_weights() -> tie_weights(...)`` call path.

The controlling flag is the top-level ``Gemma4Config.tie_word_embeddings``
(verified against HF: the top-level flag decides tying regardless of the
nested ``text_config`` value), so read it first and only fall back to
``text_config`` for configs that don't expose a top-level flag.
"""
text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
if getattr(self.config, "tie_word_embeddings", getattr(text_config, "tie_word_embeddings", False)):
self.lm_head.weight = self.model.language_model.embed_tokens.weight

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small correction after our discussion in the issue and checking Gemma4 HF behavior directly: for Gemma4ForConditionalGeneration, the controlling flag is the top-level Gemma4Config.tie_word_embeddings, not text_config.tie_word_embeddings.

I tested the conflicting cases under transformers 5.8.1:

  • top-level True, text True -> tied
  • top-level False, text False -> untied
  • top-level False, text True -> untied
  • top-level True, text False -> tied

So the current Gemma4 MoE fix is correct for the normal/default case because both flags are True, but the override should not check text_config first. Otherwise it can diverge from HF behavior when the two flags disagree.

I'd suggest changing the tie check to top-level-first, e.g.:

text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
if getattr(self.config, "tie_word_embeddings", getattr(text_config, "tie_word_embeddings", False)):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight

And in the tests, set the top-level flag explicitly too:

config = Gemma4Config(
    text_config=_make_text_config(tie_word_embeddings=tie_word_embeddings),
    tie_word_embeddings=tie_word_embeddings,
)

That keeps the fix aligned with HF's actual tying behavior instead of baking in the text_config-first rule.

@Achyuthan-S Achyuthan-S Jun 17, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, good catch. For the default case both flags are True, so the re-tie was correct, but you're right that it should follow the top-level Gemma4Config.tie_word_embeddings. Confirmed by construction on transformers 5.8.1 that all four combinations match your table. Switched the override to top-level-first (with a text_config fallback), set the top-level flag explicitly in the tests, and added a case with the flags disagreeing (top True/text False → tied, top False/text True → untied) to pin it. pytest tests/unit_tests/models/gemma4_moe/ -v → 7 passed.


def forward(
self,
input_ids: torch.Tensor | None = None,
Expand Down
144 changes: 144 additions & 0 deletions tests/unit_tests/models/gemma4_moe/test_gemma4_moe_tied_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Weight-tying tests for the Gemma4 MoE conditional-generation model.

HF's ``super().__init__()`` ties ``lm_head.weight`` to the original text
``embed_tokens``. The MoE path then replaces ``language_model`` with
``Gemma4MoETextModelBackend`` (a fresh ``embed_tokens``), which orphans that
alias. The model re-ties ``lm_head`` to the now-active embedding when
``tie_word_embeddings`` is set (Gemma defaults to ``True``); these tests pin
that behavior for both the tied and untied configs.

Runs on CPU (no CUDA / TE / DeepEP required).
"""

import torch
import torch.nn as nn
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig

from nemo_automodel.components.models.common import BackendConfig
from nemo_automodel.components.models.gemma4_moe.model import (
Gemma4ForConditionalGeneration,
Gemma4MoETextModelBackend,
)


def _make_text_config(**overrides):
    """Return a small ``Gemma4TextConfig`` for tests.

    The baseline values below describe a 2-layer model with a 256-token
    vocabulary and 4 experts. Any keyword passed in ``overrides`` replaces
    the baseline value of the same name (or adds a new field) before the
    config object is constructed.
    """
    settings = {
        "vocab_size": 256,
        "hidden_size": 64,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "head_dim": 16,
        "num_hidden_layers": 2,
        "intermediate_size": 128,
        "rms_norm_eps": 1e-6,
        "max_position_embeddings": 256,
        # Routes construction through the NeMo MoE backend.
        "enable_moe_block": True,
        "num_experts": 4,
        "top_k_experts": 2,
        "moe_intermediate_size": 64,
        "layer_types": ["full_attention", "sliding_attention"],
        "sliding_window": 128,
        "hidden_activation": "gelu_pytorch_tanh",
        "torch_dtype": "bfloat16",
        # Caller-supplied values win over the baseline above.
        **overrides,
    }
    return Gemma4TextConfig(**settings)


def _make_cpu_backend():
    """Return a ``BackendConfig`` that selects the ``torch``/``sdpa`` options.

    Every pluggable component is set to its plain-torch choice, the fake
    balanced gate is off, and the HF state-dict adapter is disabled.
    """
    options = {
        "linear": "torch",
        "attn": "sdpa",
        "rms_norm": "torch",
        "experts": "torch",
        "dispatcher": "torch",
        "fake_balanced_gate": False,
        "enable_hf_state_dict_adapter": False,
    }
    return BackendConfig(**options)


def _build(tie_word_embeddings: bool, text_tie: bool | None = None) -> Gemma4ForConditionalGeneration:
    """Construct a tiny Gemma4 MoE model with the requested tying flags.

    Args:
        tie_word_embeddings: Value for the top-level
            ``Gemma4Config.tie_word_embeddings`` flag.
        text_tie: Optional value for the nested ``text_config`` flag. When
            ``None`` the nested flag mirrors ``tie_word_embeddings``; pass an
            explicit value to create a config whose two flags disagree.

    Returns:
        The constructed model, after checking that its ``language_model`` is
        a ``Gemma4MoETextModelBackend``.
    """
    # Nested flag follows the top-level one unless the test overrides it.
    nested_tie = tie_word_embeddings if text_tie is None else text_tie
    text_config = _make_text_config(tie_word_embeddings=nested_tie)
    config = Gemma4Config(text_config=text_config, tie_word_embeddings=tie_word_embeddings)

    model = Gemma4ForConditionalGeneration(config, backend=_make_cpu_backend())

    # Sanity check: construction must have gone through the NeMo MoE backend,
    # i.e. the path that replaces language_model.
    assert isinstance(model.model.language_model, Gemma4MoETextModelBackend)
    return model


def test_tied_lm_head_shares_active_embedding_after_construction():
    """With tying enabled, ``lm_head.weight`` is the MoE ``embed_tokens.weight`` object."""
    model = _build(tie_word_embeddings=True)
    head_weight = model.lm_head.weight
    embed_weight = model.model.language_model.embed_tokens.weight
    assert head_weight is embed_weight


def test_tied_lm_head_survives_initialize_weights():
"""The tie set in __init__ must survive the bf16 cast in initialize_weights()."""
model = _build(tie_word_embeddings=True)
model.initialize_weights(dtype=torch.bfloat16, buffer_device=torch.device("cpu"))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.device has been a context manager since PyTorch 1.13 — with torch.device(...): sets the default device for factory calls, so with buffer_device: is valid, including on CPU. The test here passes buffer_device=torch.device("cpu") and runs through this exact MoE block; it's green on torch 2.10, so there's no TypeError on this path. initialize_weights() is also pre-existing and unchanged by this PR.


embed = model.model.language_model.embed_tokens.weight
lm_head = model.lm_head.weight
assert lm_head is embed
assert lm_head.dtype == torch.bfloat16


def test_untied_lm_head_is_separate():
    """With tying disabled, ``lm_head`` has its own parameter and its own storage."""
    model = _build(tie_word_embeddings=False)
    head_weight = model.lm_head.weight
    embed_weight = model.model.language_model.embed_tokens.weight
    # Distinct Parameter objects ...
    assert head_weight is not embed_weight
    # ... that also do not share the same underlying data pointer.
    assert head_weight.data_ptr() != embed_weight.data_ptr()


def test_tie_weights_hook_reties_to_active_embedding():
    """Calling ``tie_weights()`` restores the alias to the MoE embedding.

    The test first breaks the tie by assigning ``lm_head`` an independent
    copy of its weight, then checks that ``model.tie_weights()`` makes
    ``lm_head.weight`` the same object as the live ``embed_tokens.weight``
    again.
    """
    model = _build(tie_word_embeddings=True)
    language_model = model.model.language_model

    # Replace the shared weight with a detached clone so the tie is broken.
    independent_copy = model.lm_head.weight.detach().clone()
    model.lm_head.weight = nn.Parameter(independent_copy)
    assert model.lm_head.weight is not language_model.embed_tokens.weight

    model.tie_weights()
    assert model.lm_head.weight is language_model.embed_tokens.weight


def test_tie_weights_hook_is_noop_when_untied():
    """With tying disabled, ``tie_weights()`` leaves ``lm_head`` un-aliased."""
    model = _build(tie_word_embeddings=False)
    model.tie_weights()
    head_weight = model.lm_head.weight
    embed_weight = model.model.language_model.embed_tokens.weight
    assert head_weight is not embed_weight


def test_top_level_flag_controls_tie_when_flags_disagree():
    """When the two flags conflict, the top-level ``tie_word_embeddings`` decides."""
    # (top-level flag, nested text_config flag, expected "is tied")
    conflicting_cases = (
        (True, False, True),  # top-level True wins over text_config False
        (False, True, False),  # top-level False wins over text_config True
    )
    for top_level, nested, expect_tied in conflicting_cases:
        model = _build(tie_word_embeddings=top_level, text_tie=nested)
        is_tied = model.lm_head.weight is model.model.language_model.embed_tokens.weight
        assert is_tied is expect_tied
Loading