diff --git a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py
index ca22d31ee3ca5e..b2c40e27bb2b40 100644
--- a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py
+++ b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py
@@ -338,7 +338,11 @@ def write_model(
     print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
 
     if num_shards == 1:
-        loaded = [torch.load(os.path.join(input_base_path, "consolidated.pth"), map_location="cpu", mmap=True)]
+        if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")):
+            path = os.path.join(input_base_path, "consolidated.00.pth")
+        else:
+            path = os.path.join(input_base_path, "consolidated.pth")
+        loaded = [torch.load(path, map_location="cpu", mmap=True)]
     else:
         loaded = [
             torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", mmap=True)