From ac322cfd17b2d0aa42276fcd5d33ce25852884e5 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Nov 2024 13:09:34 -0800
Subject: [PATCH] Make the chat distributed

---
 llms/mlx_lm/chat.py | 57 ++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 48 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 85d32d5fc..c841287fd 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -15,6 +15,36 @@
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
+def send_chars(x, peer, group):
+    x = mx.concatenate([x, mx.zeros(16384 - len(x), dtype=x.dtype)])
+    mx.eval(mx.distributed.send(x, peer, group=group))
+
+
+def recv_chars(peer, group):
+    x = mx.distributed.recv((16384,), peer, group=group)
+    mx.eval(x)
+    x = bytes(x)
+    idx = x.index(b'\x00'*4)
+    return x[:idx].decode()
+
+
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        prompt_array = mx.array(prompt.encode())
+        for i in range(1, world.size()):
+            send_chars(prompt_array, i, world)
+        world.barrier()
+
+    else:
+        prompt = recv_chars(0, world)
+        world.barrier()
+
+    return prompt
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -53,6 +83,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
 
@@ -62,18 +93,23 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
     print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
-            break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        prompt = None
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                break
+            messages = [{"role": "user", "content": query}]
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        prompt = share_message(world, prompt)
         for response in stream_generate(
             model,
             tokenizer,
@@ -83,9 +119,12 @@ def main():
             top_p=args.top_p,
             prompt_cache=prompt_cache,
         ):
-            print(response, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response, flush=True, end="")
+        if world.rank() == 0:
+            print()
 
 
 if __name__ == "__main__":
     main()
+
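
Note on the approach: share_message pads the rank-0 prompt's UTF-8 bytes into a
fixed 16384-byte buffer, sends that buffer to every peer, and the receiver cuts
the string at the first run of null bytes. Below is a minimal standalone sketch
of the same idea for experimenting outside chat.py. It is not part of the patch:
it swaps the point-to-point send/recv for an all_sum broadcast, and the file
name, buffer size, and broadcast_prompt helper are illustrative assumptions.

# share_prompt_sketch.py -- a sketch, assuming an MLX build with a working
# distributed backend.  Rank 0 writes the prompt bytes into a zero-padded
# int32 buffer (assumed to fit in BUF_SIZE); the other ranks contribute zeros,
# so an element-wise all_sum acts as a broadcast.  Every rank then truncates
# at the first null byte.
import mlx.core as mx

BUF_SIZE = 16384  # assumed upper bound on the encoded prompt length


def broadcast_prompt(prompt, group):
    """Share a string from rank 0 with every rank in `group`."""
    if group.rank() == 0:
        data = list(prompt.encode("utf-8"))
        data = data + [0] * (BUF_SIZE - len(data))  # null padding marks the end
    else:
        data = [0] * BUF_SIZE
    buf = mx.array(data, dtype=mx.int32)
    # Only rank 0 contributes non-zero values, so the sum equals the prompt
    # bytes on every rank.
    buf = mx.distributed.all_sum(buf, group=group)
    mx.eval(buf)
    raw = bytes(buf.tolist())
    return raw.split(b"\x00", 1)[0].decode("utf-8")


if __name__ == "__main__":
    world = mx.distributed.init()
    text = "hello from rank 0" if world.rank() == 0 else None
    print(f"rank {world.rank()} got: {broadcast_prompt(text, world)!r}", flush=True)

Assuming an MPI-enabled MLX install, launching the sketch with something like
mpirun -np 2 python share_prompt_sketch.py should print the same string on both
ranks; with a single process it degenerates to a no-op broadcast.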