from typing import Dict

import torch

from safetensors.torch import load_file
from torchtune.models.convert_weights import get_mapped_key

_UNSLOTH_TO_META = {
    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
}
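# The "{}" placeholder in each template is the layer index, which
# get_mapped_key substitutes back into the mapped name, e.g.
# "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
# -> "layers.0.attention.wq.lora_a.weight".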


def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from unsloth format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
        except Exception as e:
            raise ValueError(f"Key {key} not found in mapping") from e

        converted_state_dict[new_key] = value
    return converted_state_dict
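
# Note: unsloth_to_meta only covers the LoRA A/B projection weights listed in
# _UNSLOTH_TO_META above; any other key in the adapter state dict raises a
# ValueError rather than being silently dropped.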


def load_and_convert_unsloth_to_meta(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """
    Load a checkpoint file and convert it to Meta's format.

    Args:
        checkpoint_path (str): Path to the checkpoint file.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    state_dict = load_file(checkpoint_path)
    return unsloth_to_meta(state_dict)
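

# Minimal usage sketch; the adapter filename below is hypothetical and depends
# on how the unsloth LoRA adapter was exported (a safetensors file is assumed,
# since the loader uses safetensors.torch.load_file).
if __name__ == "__main__":
    converted = load_and_convert_unsloth_to_meta("adapter_model.safetensors")
    for name, tensor in converted.items():
        # Print each converted Meta-format key alongside its tensor shape.
        print(name, tuple(tensor.shape))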