Skip to content

Commit 41286bf

Browse files
wesleytruong authored and wwwjn committed
[HF] Llama4 Text State Dict Adapter (#1662)
1 parent 10a1c54 commit 41286bf

3 files changed

Lines changed: 143 additions & 0 deletions

File tree

torchtitan/experiments/llama4/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .infra.parallelize import parallelize_llama
1717
from .model.args import TransformerModelArgs
1818
from .model.model import Transformer
19+
from .model.state_dict_adapter import Llama4StateDictAdapter
1920

2021
__all__ = [
2122
"TransformerModelArgs",
@@ -108,5 +109,6 @@
108109
build_dataloader_fn=build_hf_dataloader,
109110
build_tokenizer_fn=build_hf_tokenizer,
110111
build_loss_fn=build_cross_entropy_loss,
112+
state_dict_adapter=Llama4StateDictAdapter,
111113
)
112114
)
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
import re
9+
from collections import defaultdict
10+
from typing import Any
11+
12+
import torch
13+
14+
logger = logging.getLogger()
15+
16+
from torchtitan.protocols.state_dict_adapter import StateDictAdapter
17+
18+
from .args import TransformerModelArgs
19+
20+
21+
class Llama4StateDictAdapter(StateDictAdapter):
    """Translate Llama 4 text-model state dicts between torchtitan and HF names.

    Most parameters are renamed one-to-one through ``from_hf_map``. The routed
    expert weights need extra handling:

    * torchtitan ``moe.experts.w2`` <-> HF ``experts.down_proj`` (last two
      dims transposed).
    * torchtitan ``moe.experts.w1`` and ``moe.experts.w3`` <-> HF
      ``experts.gate_up_proj``: both are transposed on their last two dims and
      concatenated (w1 first, then w3) along the last dim when exporting, and
      split back into two equal chunks when importing.
    * torchtitan ``moe.expert_bias`` has no HF counterpart (it is mapped from
      the ``None`` key) and is dropped by ``to_hf``.

    Keys that are not covered by the mapping are silently skipped in both
    directions.
    """

    # Matches the per-layer index, i.e. the digits right after "layers.".
    _LAYER_INDEX_RE = re.compile(r"layers\.(\d+)")

    # Abstract (layer index replaced by "{}") names of the routed-expert weights.
    _HF_EXPERTS_GATE_UP = (
        "language_model.model.layers.{}.feed_forward.experts.gate_up_proj"
    )
    _HF_EXPERTS_DOWN = "language_model.model.layers.{}.feed_forward.experts.down_proj"
    _TT_EXPERTS_W1 = "layers.{}.moe.experts.w1"
    _TT_EXPERTS_W2 = "layers.{}.moe.experts.w2"
    _TT_EXPERTS_W3 = "layers.{}.moe.experts.w3"

    def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None):
        """Store the model args / asset path and build the HF -> torchtitan key map.

        Args:
            model_args: Model arguments, forwarded to the base adapter.
            hf_assets_path: Path to the HF assets, or ``None``; forwarded to the
                base adapter.
        """
        super().__init__(model_args, hf_assets_path)

        self.model_args = model_args
        self.hf_assets_path = hf_assets_path
        # HF abstract key -> torchtitan abstract key. "{}" stands for the layer
        # index. A ``None`` HF key marks a torchtitan-only parameter.
        self.from_hf_map = {
            "language_model.model.embed_tokens.weight": "tok_embeddings.weight",
            "language_model.model.norm.weight": "norm.weight",
            "language_model.lm_head.weight": "output.weight",
            "language_model.model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
            "language_model.model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
            "language_model.model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
            "language_model.model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
            "language_model.model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
            "language_model.model.layers.{}.feed_forward.router.weight": "layers.{}.moe.router.gate.weight",
            "language_model.model.layers.{}.feed_forward.experts.down_proj": "layers.{}.moe.experts.w2",
            None: "layers.{}.moe.expert_bias",
            "language_model.model.layers.{}.feed_forward.shared_expert.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight",
            "language_model.model.layers.{}.feed_forward.shared_expert.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight",
            "language_model.model.layers.{}.feed_forward.shared_expert.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight",
            "language_model.model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        }

    @classmethod
    def _split_layer_key(cls, key: str) -> tuple[str, str | None]:
        """Return ``(abstract_key, layer_num)`` for a fully qualified name.

        The layer index following ``"layers."`` is replaced by ``"{}"`` in the
        abstract key. If the key carries no layer index, it is returned
        unchanged together with ``None``.
        """
        match = cls._LAYER_INDEX_RE.search(key)
        if match is None:
            return key, None
        abstract_key = key[: match.start(1)] + "{}" + key[match.end(1) :]
        return abstract_key, match.group(1)

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert a torchtitan state dict to HF naming.

        Args:
            state_dict: Mapping from torchtitan parameter names to values.

        Returns:
            A new dict keyed by HF parameter names. Directly mapped entries come
            first (in input order), followed by one combined ``gate_up_proj``
            entry per layer that had routed-expert weights.

        Raises:
            KeyError: If a layer provides only one of ``moe.experts.w1`` /
                ``moe.experts.w3``, so ``gate_up_proj`` cannot be built.
        """
        to_hf_map = {tt_key: hf_key for hf_key, tt_key in self.from_hf_map.items()}
        fused_sources = (self._TT_EXPERTS_W1, self._TT_EXPERTS_W3)

        hf_state_dict: dict[str, Any] = {}
        # layer_num -> {abstract torchtitan key: value}; fused after the scan
        # because w1 and w3 of a layer arrive as separate entries.
        pending: defaultdict[str, dict[str, Any]] = defaultdict(dict)

        for key, value in state_dict.items():
            abstract_key, layer_num = self._split_layer_key(key)

            if abstract_key in to_hf_map:
                hf_key = to_hf_map[abstract_key]
                if hf_key is None:
                    # torchtitan-only parameter (e.g. expert_bias): nothing to export
                    continue
                if abstract_key == self._TT_EXPERTS_W2:
                    # we transpose the expert weights for torchtitan optimization purpose
                    value = value.transpose(-1, -2)
                if layer_num is not None:
                    hf_key = hf_key.format(layer_num)
                hf_state_dict[hf_key] = value
            elif layer_num is not None and abstract_key in fused_sources:
                pending[layer_num][abstract_key] = value

        for layer_num, parts in pending.items():
            missing = [
                template.format(layer_num)
                for template in fused_sources
                if template not in parts
            ]
            if missing:
                raise KeyError(
                    f"Cannot build {self._HF_EXPERTS_GATE_UP.format(layer_num)}: "
                    f"missing {missing} in state dict"
                )
            # Order matters: from_hf splits the result back as (w1, w3).
            # we transpose the expert weights for torchtitan optimization purpose
            hf_state_dict[self._HF_EXPERTS_GATE_UP.format(layer_num)] = torch.cat(
                [parts[template].transpose(-1, -2) for template in fused_sources],
                dim=-1,
            )

        return hf_state_dict

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert an HF state dict to torchtitan naming.

        Args:
            hf_state_dict: Mapping from HF parameter names to values.

        Returns:
            A new dict keyed by torchtitan parameter names. Each
            ``gate_up_proj`` entry is split into ``moe.experts.w1`` and
            ``moe.experts.w3``.
        """
        state_dict: dict[str, Any] = {}

        for key, value in hf_state_dict.items():
            abstract_key, layer_num = self._split_layer_key(key)

            if abstract_key in self.from_hf_map:
                tt_key = self.from_hf_map[abstract_key]
                if tt_key is None:
                    continue
                if abstract_key == self._HF_EXPERTS_DOWN:
                    # we transpose the expert weights for torchtitan optimization purpose
                    value = value.transpose(-1, -2)
                if layer_num is not None:
                    tt_key = tt_key.format(layer_num)
                state_dict[tt_key] = value
            elif layer_num is not None and abstract_key == self._HF_EXPERTS_GATE_UP:
                # First half of the last dim is w1, second half is w3
                # (the same order to_hf concatenates them in).
                w1, w3 = value.chunk(2, dim=-1)
                # we transpose the expert weights for torchtitan optimization purpose
                state_dict[self._TT_EXPERTS_W1.format(layer_num)] = w1.transpose(-1, -2)
                state_dict[self._TT_EXPERTS_W3.format(layer_num)] = w3.transpose(-1, -2)

        return state_dict
File renamed without changes.

0 commit comments

Comments
 (0)