Skip to content

Commit 4738d9e

Browse files
committed
Merge branch 'main' into release/4.1
2 parents a3651ca + d044111 commit 4738d9e

3 files changed

Lines changed: 43 additions & 24 deletions

File tree

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2627,7 +2627,9 @@ def _prepare_model_inputs(self, inputs: 'DataType') -> Dict[str, Any]:
26272627
k: v
26282628
for k, v in inputs.items() if k not in [
26292629
'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
2630-
'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'rollout_per_token_logps'
2630+
'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'rollout_per_token_logps', 'rollout_logprobs',
2631+
'is_truncated', 'add_eos', 'response_token_ids', 'prompt_id', 'rollout_is_weights', 'finish_reason',
2632+
'request_id'
26312633
]
26322634
}
26332635

swift/rlhf_trainers/utils.py

Lines changed: 31 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -852,25 +852,17 @@ def prepare_fsdp(model, accelerator, evaluation_mode: bool = True):
852852
return model
853853

854854

855-
def patch_vllm_moe_model_weight_loader(model):
856-
"""
857-
Patch vLLM MoE model to add weight_loader attribute to expert weights.
855+
_moe_model_registry_cache = None
858856

859-
This is a workaround for a bug in vLLM 0.8.2 where MoE weights (w13_weight, w2_weight)
860-
don't have the weight_loader attribute, causing AttributeError during weight loading.
861-
Code adapted from verl/verl/utils/vllm/patch.py
862857

863-
Args:
864-
model: The vLLM model to patch.
865-
"""
866-
import importlib
858+
def _get_moe_model_registry():
867859

868-
# Check if already patched (idempotent)
869-
if getattr(model, '_swift_moe_weight_loader_patched', False):
870-
return
860+
global _moe_model_registry_cache
861+
if _moe_model_registry_cache is not None:
862+
return _moe_model_registry_cache
863+
864+
import importlib
871865

872-
# MoE model configurations: (module_path, class_names, mlp_attr)
873-
# mlp_attr specifies the attribute name for the MoE layer in each model
874866
moe_model_configs = [
875867
('vllm.model_executor.models.deepseek_v2', ('DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM'), 'mlp'),
876868
('vllm.model_executor.models.mixtral', ('MixtralForCausalLM', ), 'block_sparse_moe'),
@@ -881,7 +873,6 @@ def patch_vllm_moe_model_weight_loader(model):
881873
('vllm.model_executor.models.kimi_vl', ('KimiVLForConditionalGeneration', ), 'mlp'),
882874
]
883875

884-
# Build supported models list and MLP attribute mapping
885876
supported_moe_models = []
886877
mlp_attr_mapping = {}
887878

@@ -893,10 +884,32 @@ def patch_vllm_moe_model_weight_loader(model):
893884
model_class = getattr(module, class_name)
894885
supported_moe_models.append(model_class)
895886
mlp_attr_mapping[model_class] = mlp_attr
896-
except (ImportError, AttributeError):
887+
except (ImportError, AttributeError, RuntimeError):
897888
pass
898889

899-
# Early return if no MoE models are supported
890+
_moe_model_registry_cache = (supported_moe_models, mlp_attr_mapping)
891+
return _moe_model_registry_cache
892+
893+
894+
def patch_vllm_moe_model_weight_loader(model):
895+
"""
896+
Patch vLLM MoE model to add weight_loader attribute to expert weights.
897+
898+
This is a workaround for a bug in vLLM 0.8.2 where MoE weights (w13_weight, w2_weight)
899+
don't have the weight_loader attribute, causing AttributeError during weight loading.
900+
Code adapted from verl/verl/utils/vllm/patch.py
901+
902+
Args:
903+
model: The vLLM model to patch.
904+
"""
905+
# Check if already patched (idempotent).
906+
# Note: the flag can be lost when vLLM sleep/wake_up recreates the model
907+
# object, so the expensive import step is cached in _get_moe_model_registry.
908+
if getattr(model, '_swift_moe_weight_loader_patched', False):
909+
return
910+
911+
supported_moe_models, mlp_attr_mapping = _get_moe_model_registry()
912+
900913
if not supported_moe_models:
901914
return
902915

swift/template/templates/gemma.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
242242
register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3n, template_cls=Gemma3nTemplate))
243243

244244

245-
class Gemma4Template(Gemma3Template):
245+
class Gemma4Template(Template):
246246
placeholder_tokens = ['<|image|>', '<|audio|>', '<|video|>']
247247

248248
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
@@ -267,19 +267,23 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
267267
return_tensors='pt',
268268
add_special_tokens=False,
269269
)
270-
splited_tokens = iter(self._split_list(media_inputs['input_ids'][0].tolist(), split_token))
270+
splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)
271271
media_inputs.pop('input_ids')
272272
media_inputs.pop('attention_mask')
273273
input_ids = encoded['input_ids']
274274
labels = encoded['labels']
275275
loss_scale = encoded.get('loss_scale', None)
276276

277-
def _get_new_tokens(i):
278-
return next(splited_tokens)
279-
280277
idx_list = []
281278
for key in ['image', 'video', 'audio']:
282279
idx_list += findall(input_ids, getattr(self.config, f'{key}_token_id'))
280+
sorted_order = sorted(range(len(idx_list)), key=lambda i: idx_list[i])
281+
idx_list = [idx_list[i] for i in sorted_order]
282+
splited_tokens = [splited_tokens[i] for i in sorted_order]
283+
284+
def _get_new_tokens(i):
285+
return splited_tokens[i]
286+
283287
if idx_list:
284288
input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
285289
_get_new_tokens)

0 commit comments

Comments
 (0)