from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

_UNSLOTH_TO_META = {
    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
}
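
# Note: get_mapped_key (from torchtune) matches a concrete checkpoint key against
# the templated keys above, carrying the layer index through the "{}" placeholder.
# Illustrative example (hypothetical key, not taken from a real checkpoint):
#   "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
#       -> "layers.0.attention.wq.lora_a.weight"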


def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from unsloth format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
        except Exception as e:
            # Avoid a bare `except:` (which would also swallow KeyboardInterrupt)
            # and chain the original error for easier debugging.
            raise ValueError(f"Key {key} not found in mapping") from e

        converted_state_dict[new_key] = value
    return converted_state_dict
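

# A minimal usage sketch, assuming an unsloth/PEFT LoRA adapter checkpoint on
# disk; "adapter_model.bin" is a hypothetical filename, not part of this module:
#
#   unsloth_sd = torch.load("adapter_model.bin", map_location="cpu", weights_only=True)
#   meta_sd = unsloth_to_meta(unsloth_sd)
#   # meta_sd now keys the LoRA weights like "layers.0.attention.wq.lora_a.weight"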