Commit

Make the chat distributed
angeloskath committed Nov 6, 2024
1 parent 1c52719 commit 6041c7e
Showing 1 changed file with 59 additions and 21 deletions.
llms/mlx_lm/chat.py
@@ -6,14 +6,37 @@
 import mlx.core as mx
 
 from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
-from .utils import load, stream_generate
+from .utils import load, stream_generate, wired_limit
 
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+MAX_PROMPT_CHARS = 16384
+
+
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        prompt_array = mx.array(prompt.encode())
+        prompt_array = mx.concatenate(
+            [prompt_array, mx.zeros(MAX_PROMPT_CHARS - len(prompt_array), dtype=mx.uint8)]
+        )
+
+    else:
+        prompt_array = mx.zeros(MAX_PROMPT_CHARS, dtype=mx.uint8)
+
+    with mx.stream(mx.cpu):
+        prompt_array = mx.distributed.all_sum(prompt_array)
+        mx.eval(prompt_array)
+    prompt = bytes(prompt_array)
+    idx = prompt.index(b'\x00' * 4)
+    return prompt[:idx].decode()
+
 
 def setup_arg_parser():
     """Set up and return the argument parser."""
@@ -53,6 +76,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
 
@@ -62,30 +86,44 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    world.barrier()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
-    while True:
-        query = input(">> ")
-        if query == "q":
-            break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-        for response in stream_generate(
-            model,
-            tokenizer,
-            prompt,
-            args.max_tokens,
-            temp=args.temp,
-            top_p=args.top_p,
-            prompt_cache=prompt_cache,
-        ):
-            print(response, flush=True, end="")
-        print()
+    with wired_limit(model):
+        while True:
+            prompt = None
+            if world.rank() == 0:
+                query = input(">> ")
+                if query == "q":
+                    prompt = query
+                else:
+                    messages = [{"role": "user", "content": query}]
+                    prompt = tokenizer.apply_chat_template(
+                        messages, tokenize=False, add_generation_prompt=True
+                    )
+            prompt = share_message(world, prompt)
+            if prompt == "q":
+                break
+            for response in stream_generate(
+                model,
+                tokenizer,
+                prompt,
+                args.max_tokens,
+                temp=args.temp,
+                top_p=args.top_p,
+                prompt_cache=prompt_cache,
+            ):
+                if world.rank() == 0:
+                    print(response, flush=True, end="")
+            if world.rank() == 0:
+                print()
+            mx.synchronize()
 
 
 if __name__ == "__main__":
     main()
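
The heart of the change is share_message, which gets the rank-0 prompt to every other rank: rank 0 writes the prompt's bytes into a fixed-size uint8 buffer padded with zeros, the other ranks contribute all-zero buffers, and mx.distributed.all_sum then reproduces rank 0's bytes on every rank. Below is a minimal standalone sketch of that same padded all_sum broadcast, not code from the commit; the function name broadcast_text, the BUFFER_CHARS size, and the mpirun launch line are illustrative assumptions.

# Hedged sketch (not from the commit): broadcast a string from rank 0 using the
# same padded all_sum trick as share_message. Runs unchanged as a single process.
import mlx.core as mx

BUFFER_CHARS = 64  # illustrative; chat.py uses MAX_PROMPT_CHARS = 16384


def broadcast_text(world, text=None):
    if world.size() == 1:
        return text
    if world.rank() == 0:
        data = list(text.encode())
        # Rank 0 carries the real bytes, zero-padded to a fixed length.
        buf = mx.array(data + [0] * (BUFFER_CHARS - len(data)), dtype=mx.uint8)
    else:
        # Every other rank contributes an all-zero buffer to the sum.
        buf = mx.zeros(BUFFER_CHARS, dtype=mx.uint8)
    with mx.stream(mx.cpu):
        buf = mx.distributed.all_sum(buf)  # zeros + message = message on all ranks
        mx.eval(buf)
    raw = bytes(buf.tolist())
    return raw[: raw.index(b"\x00")].decode()  # strip the zero padding


if __name__ == "__main__":
    world = mx.distributed.init()
    message = "hello from rank 0" if world.rank() == 0 else None
    print(world.rank(), broadcast_text(world, message), flush=True)

Run as a single process this simply passes the text through; to exercise more than one rank it would be launched with a distributed runner such as mpirun -np 2 python broadcast_text.py (assuming MLX's MPI backend is configured), which is the same mechanism a multi-node run of the chat script would typically use.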
