Commit b5ef406: fix lazy load

awni committed on Jan 14, 2025
1 parent: 1344236

Showing 2 changed files with 8 additions and 5 deletions.
llms/mlx_lm/chat.py (6 additions, 3 deletions)

@@ -15,6 +15,7 @@
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+
 def share_message(world, prompt):
     if world.size() == 1:
         return prompt
@@ -30,7 +31,7 @@ def share_message(world, prompt):
     if world.rank() == 0:
         prompt = mx.array(prompt)
     else:
-        prompt = mx.array([0]*len(prompt))
+        prompt = mx.array([0] * len(prompt))
     return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
@@ -86,7 +87,10 @@ def main():
     )
 
     print(f"Node {world.rank()} of {world.size()}", flush=True)
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
     world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
@@ -119,4 +123,3 @@
 
 if __name__ == "__main__":
     main()
-
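The share_message context above relies on a common trick for broadcasting data when only an all-reduce primitive is available: the root rank contributes the real values and every other rank contributes zeros, so the element-wise sum across ranks equals the root's data. Below is a minimal self-contained sketch of that pattern using MLX's mx.distributed API; the helper name broadcast_tokens and the pre-agreed length argument are illustrative, not part of mlx_lm.

    import mlx.core as mx

    def broadcast_tokens(world, tokens, length):
        # All ranks must already agree on `length` (share_message does this
        # with a preliminary all_sum over a one-element size array).
        if world.rank() == 0:
            x = mx.array(tokens)        # root contributes the real tokens
        else:
            x = mx.array([0] * length)  # zeros are the identity for the sum
        # Summing across ranks reproduces the root's values on every rank.
        return mx.distributed.all_sum(x, stream=mx.cpu).tolist()

    world = mx.distributed.init()
    tokens = [101, 2023, 102] if world.rank() == 0 else None
    print(broadcast_tokens(world, tokens, length=3))

With a single process, all_sum is a no-op and the root's tokens pass straight through, which is why share_message can return early when world.size() == 1.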
llms/mlx_lm/utils.py (2 additions, 2 deletions)

@@ -745,7 +745,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model, config = load_model(model_path, sequential_load, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
     model.eval()
@@ -759,7 +759,7 @@ def fetch_from_hub(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )
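The lazy-load bug itself is in the first hunk above: load was passing sequential_load and lazy to load_model positionally, so each flag landed in the other's parameter slot, and the fix binds them by keyword instead. Here is a toy reproduction of the failure mode, assuming load_model's signature puts lazy before sequential_load (the signature is not shown in this diff, so the stand-in below is an assumption):

    def load_model(model_path, lazy=False, sequential_load=False):
        # Stand-in for mlx_lm.utils.load_model; only parameter order matters here.
        print(f"lazy={lazy}, sequential_load={sequential_load}")

    sequential_load, lazy = True, False

    # Old call: positional arguments bind by position, silently swapping the flags.
    load_model("some/path", sequential_load, lazy)
    # -> lazy=True, sequential_load=False

    # Fixed call: keyword arguments bind by name, regardless of order.
    load_model("some/path", lazy=lazy, sequential_load=sequential_load)
    # -> lazy=False, sequential_load=True

The second hunk applies the same discipline in fetch_from_hub: passing lazy=lazy by keyword keeps the call correct even if parameters are later added to load_model ahead of it.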
