Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions tools/convert_fsdp_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import torch
import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
)
from typing_extensions import override


Expand Down Expand Up @@ -105,6 +109,13 @@ def _strip_best_prefix(keys: list[str], target_keys: set[str]) -> tuple[str, int
return best_prefix, best_match


def _build_hf_model(config: AutoConfig) -> torch.nn.Module:
    """Build a Hugging Face model from ``config``, picking the auto class to use.

    The auto class is selected by inspecting the config: when it exposes a
    ``vision_config`` attribute the model is created through
    ``AutoModelForImageTextToText``; otherwise ``AutoModelForCausalLM`` is used.
    The detected ``model_type`` and the chosen auto class are printed.

    Args:
        config: Model configuration object; its ``model_type`` attribute is
            read for logging and the presence of ``vision_config`` drives the
            class selection.

    Returns:
        The model returned by ``<auto class>.from_config(config,
        trust_remote_code=True)``.
    """
    print(f"Detected model type: {config.model_type}")

    # Only the presence of the attribute is checked, not its value.
    if hasattr(config, "vision_config"):
        model_cls = AutoModelForImageTextToText
    else:
        model_cls = AutoModelForCausalLM

    print(f"Loaded with {model_cls.__name__}")
    return model_cls.from_config(config, trust_remote_code=True)


def _convert_fsdp_to_hf(
origin_hf_dir: str,
input_dir: str,
Expand All @@ -118,7 +129,7 @@ def _convert_fsdp_to_hf(
tensor_items = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

config = AutoConfig.from_pretrained(origin_hf_dir, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_config(config)
hf_model = _build_hf_model(config)
target_keys = set(hf_model.state_dict().keys())

best_prefix, best_match = _strip_best_prefix(list(tensor_items.keys()), target_keys)
Expand Down