Chat: dynamically set kv-cache size #1583

Merged (3 commits) on Jul 15, 2024
Changes from 2 commits
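
This PR makes litgpt chat size the KV cache per conversation turn instead of pre-allocating it for the model's full context length: main() gains a max_new_tokens argument (default 50), process_prompt() builds or grows the cache only when the encoded prompt plus max_new_tokens exceeds the current model.max_seq_length, and the unconditional model.set_kv_cache(batch_size=1) call at model construction is removed.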
23 changes: 17 additions & 6 deletions litgpt/chat/base.py
@@ -120,11 +120,18 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
     return tokens_generated


-def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
+def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
     prompt = prompt_style.apply(prompt=prompt)
     encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
+
+    first_turn = model.mask_cache is None
+    max_returned_tokens = encoded_prompt.size(0) + max_new_tokens
+    if first_turn or max_returned_tokens > model.max_seq_length:
+        model.max_seq_length = max_returned_tokens
+        model.set_kv_cache(batch_size=1, device=fabric.device)
+
     y = generate(
-        model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
+        model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
     )
     fabric.print(">> Reply: ", end="")
     t0 = time.perf_counter()
@@ -140,7 +147,7 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
     fabric.print()


-def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
+def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
     while True:
         try:
             if not multiline:
@@ -162,13 +169,14 @@ def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, top
         if not prompt or prompt in ("!quit", "!exit"):
             break

-        process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens)
+        process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens)


 @torch.inference_mode()
 def main(
     checkpoint_dir: Path,
     *,
+    max_new_tokens: int = 50,
     top_k: Optional[int] = 200,
     top_p: float = 1.0,
     temperature: float = 0.8,
@@ -183,6 +191,7 @@ def main(
     Args:
         checkpoint_dir: A local path to a directory containing the model weights or a valid model name.
             You can get a list of valid model names via the `litgpt download list` command line argument.
+        max_new_tokens: The number of generation steps to take.
         top_k: The number of top most probable tokens to consider in the sampling process.
         top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
             In top-p sampling, the next token is sampled from the highest probability tokens
@@ -237,8 +246,6 @@ def main(

     with fabric.init_module(empty_init=True):
         model = GPT(config)
-        # enable the kv cache
-        model.set_kv_cache(batch_size=1)
     load_checkpoint(fabric, model, checkpoint_path)
     model.eval()

@@ -272,7 +279,11 @@ def main(
         prompt_style=prompt_style,
         fabric=fabric,
         temperature=temperature,
+        max_new_tokens=max_new_tokens,
         top_k=top_k,
         top_p=top_p,
         stop_tokens=stop_tokens
     )

     if fabric.device.type == "cuda":
         fabric.print(f"\nMemory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
4 changes: 2 additions & 2 deletions tests/test_chat.py
@@ -109,13 +109,13 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te

     out, err = StringIO(), StringIO()
     with redirect_stdout(out), redirect_stderr(err):
-        chat.main(temperature=2.0, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)
+        chat.main(temperature=2.0, max_new_tokens=10, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)

     # decoding is done per each generated item
     assert len(tokenizer_mock.return_value.decode.mock_calls) == generate_mock.return_value.numel()
     assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
     assert generate_mock.mock_calls == [
-        call(ANY, tensor_like, 128, temperature=2.0, top_k=2, top_p=0.9, stop_tokens=([tokenizer_mock.return_value.eos_id],))
+        call(ANY, tensor_like, 13, temperature=2.0, top_k=2, top_p=0.9, stop_tokens=([tokenizer_mock.return_value.eos_id],))
     ]
     # only the generated result is printed to stdout
     assert re.match(r".*Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)
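
The expected third positional argument to generate drops from 128 (the test model's full max_seq_length, which the old code passed as the generation limit) to 13, presumably the mocked 3-token prompt plus the max_new_tokens=10 now passed to chat.main.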