Skip to content

Commit

Permalink
Improve get_mapped_key heuristic (#1525)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbontrager authored Sep 9, 2024
1 parent 68d4f3e commit 66590b4
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torchtune/models/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@

def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
try:
if "layers" in key:
# Checks if there is a layer # in the key
if any(k.isdigit() for k in key.split(".")):
# Replace layer number with "{}" to create key for lookup
abstract_key = re.sub(r"(\.\d+)", ".{}", key)
layer_num = re.search(r"\d+", key).group(0)
Expand Down

0 comments on commit 66590b4

Please sign in to comment.