From b5ef4062ee1378a755d0d6ce616aff09cfb36d52 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 14 Jan 2025 13:14:48 -0800
Subject: [PATCH] fix lazy load

---
 llms/mlx_lm/chat.py  | 9 ++++++---
 llms/mlx_lm/utils.py | 4 ++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 380db082..1fa81ad6 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -15,6 +15,7 @@
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+
 def share_message(world, prompt):
     if world.size() == 1:
         return prompt
@@ -30,7 +31,7 @@ def share_message(world, prompt):
 
     if world.rank() == 0:
         prompt = mx.array(prompt)
     else:
-        prompt = mx.array([0]*len(prompt))
+        prompt = mx.array([0] * len(prompt))
     return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
 
@@ -86,7 +87,10 @@ def main():
     )
     print(f"Node {world.rank()} of {world.size()}", flush=True)
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    print(
+        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
+        flush=True,
+    )
     world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
@@ -119,4 +123,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 82788e14..261fd557 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -745,7 +745,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model, config = load_model(model_path, sequential_load, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
     model.eval()
@@ -759,7 +759,7 @@
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
     tokenizer = load_tokenizer(
         model_path, eos_token_ids=config.get("eos_token_id", None)
     )