diff --git a/tools/convert_fsdp_to_hf.py b/tools/convert_fsdp_to_hf.py
index 142730c81..1ed4785b1 100644
--- a/tools/convert_fsdp_to_hf.py
+++ b/tools/convert_fsdp_to_hf.py
@@ -6,7 +6,11 @@
 import torch
 import torch.distributed.checkpoint as dist_cp
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForImageTextToText,
+)
 from typing_extensions import override
@@ -105,6 +109,13 @@ def _strip_best_prefix(keys: list[str], target_keys: set[str]) -> tuple[str, int]:
     return best_prefix, best_match
 
 
+def _build_hf_model(config: AutoConfig) -> torch.nn.Module:
+    print(f"Detected model type: {config.model_type}")
+    model_cls = AutoModelForImageTextToText if hasattr(config, "vision_config") else AutoModelForCausalLM
+    print(f"Loaded with {model_cls.__name__}")
+    return model_cls.from_config(config, trust_remote_code=True)
+
+
 def _convert_fsdp_to_hf(
     origin_hf_dir: str,
     input_dir: str,
@@ -118,7 +129,7 @@ def _convert_fsdp_to_hf(
     tensor_items = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
 
     config = AutoConfig.from_pretrained(origin_hf_dir, trust_remote_code=True)
-    hf_model = AutoModelForCausalLM.from_config(config)
+    hf_model = _build_hf_model(config)
 
     target_keys = set(hf_model.state_dict().keys())
     best_prefix, best_match = _strip_best_prefix(list(tensor_items.keys()), target_keys)
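
For context, the branch in `_build_hf_model` relies on multimodal configs exposing a nested `vision_config`, which text-only configs lack. A minimal standalone sketch of that same check, assuming only the `transformers` classes already imported by this patch (the `hf_dir` argument and `pick_model_class` name are placeholders, not part of the change):

```python
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
)


def pick_model_class(hf_dir: str):
    # Only the config is loaded here; no weights are materialized.
    config = AutoConfig.from_pretrained(hf_dir, trust_remote_code=True)
    # Image-text-to-text configs nest their vision tower settings under
    # `vision_config`; plain causal-LM configs do not define that attribute.
    if hasattr(config, "vision_config"):
        return AutoModelForImageTextToText
    return AutoModelForCausalLM
```

This mirrors the class selection `_build_hf_model` performs before `from_config` instantiates the empty model whose state-dict keys are used for prefix matching.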