
[Optimization] Advance parser concurrently with model forward pass #1065

Merged: 22 commits merged on Nov 10, 2024

Commits (22)
820ed87
Simplify Engine.__call__ to remove second call to get_next_token
hudson-ai Oct 25, 2024
9f434ad
simplify parser loop since we know mask will be non-none on every ite…
hudson-ai Oct 25, 2024
e5ee621
simplify model.__call__ loop again
hudson-ai Oct 25, 2024
1b49dd3
prototype concurrent parser
hudson-ai Oct 25, 2024
f9d38fd
generator cleanup
hudson-ai Oct 26, 2024
facafb7
move Mock temperature hook from get_next_token to sample_with_tempera…
hudson-ai Oct 26, 2024
28053c1
wrong assert
hudson-ai Oct 26, 2024
baa1d6b
silence cleanup exceptions in garbage collection
hudson-ai Oct 26, 2024
b1545ed
Simplify ByteParser
hudson-ai Oct 26, 2024
d97d72a
Allow non-concurrent path with get_next_token
hudson-ai Oct 26, 2024
f59f1fb
wrong assert
hudson-ai Oct 26, 2024
8aa882c
Merge branch 'main' into parallel_parser
hudson-ai Oct 28, 2024
f8779e7
test associativity on get_logits rather than get_next_token
hudson-ai Oct 28, 2024
fe33742
fix associativity test to get the args of the FIRST call
hudson-ai Oct 29, 2024
e0d8e69
use has_pending_stop to prevent unnecessary forward pass
hudson-ai Oct 31, 2024
1b2c5e2
comment
hudson-ai Oct 31, 2024
e718137
prevent parser cleanup from raising exceptions at system exit
hudson-ai Oct 31, 2024
27b3d19
bump llg
hudson-ai Oct 31, 2024
a22a40c
move LLInterpreterResponse validation into thread with mid_process to…
hudson-ai Nov 5, 2024
93cff1f
add some comments
hudson-ai Nov 5, 2024
90e5b6e
fix exception
hudson-ai Nov 5, 2024
42a5e5e
Merge branch 'main' into parallel_parser
hudson-ai Nov 6, 2024
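The optimization in a nutshell: previously the llguidance mask computation (mid_process) and the model forward pass ran back to back on every step; with this change the mask is built on a worker thread while the model computes logits, and its result is only awaited once the logits are ready. A minimal sketch of the pattern, with compute_mask and compute_logits as hypothetical stand-ins for ll_interpreter.mid_process and Engine.get_logits:

import numpy as np
from concurrent.futures import ThreadPoolExecutor

def compute_mask(tokens: list[int]) -> np.ndarray:
    # Stand-in for ll_interpreter.mid_process(): one entry per vocab id, truthy = allowed.
    return np.ones(1000, dtype=bool)

def compute_logits(tokens: list[int]) -> np.ndarray:
    # Stand-in for the (expensive) model forward pass.
    return np.random.default_rng(0).standard_normal(1000)

pool = ThreadPoolExecutor(max_workers=1)
tokens = [1, 2, 3]

mask_future = pool.submit(compute_mask, tokens)  # mask construction starts on the worker thread...
logits = compute_logits(tokens)                  # ...while the forward pass runs here
mask = mask_future.result()                      # ideally already finished by the time we ask

next_token = int(np.argmax(np.where(mask, logits, -np.inf)))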
134 changes: 85 additions & 49 deletions guidance/_parser.py
@@ -1,6 +1,7 @@
import json
import os
from typing import Any, Generator, Optional, Tuple, Union
from typing import Any, Generator, Optional, Union
from concurrent.futures import ThreadPoolExecutor, Future

import llguidance # type: ignore[import-untyped]
import numpy as np
@@ -11,7 +12,6 @@
from .models._byte_tokenizer import ByteTokenizer
from .models._tokenizer import Tokenizer


class TokenParserException(Exception):
pass

@@ -52,8 +52,10 @@ def __init__(
serialized_grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self._threadpool = ThreadPoolExecutor(max_workers=1)
self._generator = self._parse(prompt, ensure_bos_token)
self._done = False
self._has_pending_stop = False

def is_accepting(self) -> bool:
return self.ll_interpreter.is_accepting()
@@ -63,12 +65,13 @@ def done(self) -> bool:

def advance(
self, token: Optional[int]
) -> Tuple[Optional[GenData], EngineCallResponse]:
try:
return self._generator.send(token)
except StopIteration as e:
self._done = True
return None, e.value
) -> tuple[list[int], Future[tuple[Optional[bytes], LLInterpreterResponse]]]:
if self.done():
raise TokenParserException("Cannot advance on a done parser")
return self._generator.send(token)

def has_pending_stop(self) -> bool:
return self._has_pending_stop

def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:
prompt_tokens = self.ll_interpreter.process_prompt(
@@ -84,55 +87,78 @@ def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:

return self.tokenizer.recode(prompt_tokens)

def mid_process(self) -> tuple[Optional[bytes], LLInterpreterResponse]:
mask, ll_response_string = self.ll_interpreter.mid_process()
ll_response = LLInterpreterResponse.model_validate_json(ll_response_string)
return mask, ll_response

def _parse(
self,
prompt: bytes,
ensure_bos_token: bool,
) -> Generator[Tuple[Optional[GenData], EngineCallResponse], Optional[int], EngineCallResponse]:
) -> Generator[
tuple[
list[int],
Future[tuple[Optional[bytes], LLInterpreterResponse]],
],
Optional[int],
None
]:
tokens = self._process_prompt(prompt=prompt, ensure_bos_token=ensure_bos_token)

while True:
mask, resp = self.ll_interpreter.mid_process()
r = LLInterpreterResponse.model_validate_json(resp)
response = r.progress.to_engine_call_response()
if r.stop:
# Note: need to call/set has_pending_stop before spinning up the mid_process future
# as the two methods cannot be called concurrently
self._has_pending_stop = self.ll_interpreter.has_pending_stop()
mid_process_future = self._threadpool.submit(self.mid_process)
token = yield (tokens, mid_process_future)

# Upstairs should have already waited on this future
mask, ll_response = mid_process_future.result()

if ll_response.stop:
# This is the only case in which the mask is None
assert mask is None
# If we're done, our caller should NOT send us a token
if token is not None:
raise TokenParserException(f"Expected None, got token {token}")
Review thread on this line:

Contributor: if mask is None, isn't it in accepting mode? any tokens should be accepted?

hudson-ai (Collaborator, Author), Nov 4, 2024: The mask should never be None unless the parser is actually done (i.e. we should not be accepting ANY tokens, as the loop should be stopping). This condition should be equivalent to ll_response.stop if we were to parse the string in the second slot of the future above.

hudson-ai (Collaborator, Author): Note the .cleanup code, which is currently responsible for sending the final None token to get the generator loop to break. Let me know if you have any better ideas on how to structure it!

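As an aside for readers less familiar with Python's generator send protocol, which is what lets the cleanup code mentioned above end the loop by sending a final None: the generator pauses at each yield, hands a value out, and resumes with whatever the caller passes to send(). A self-contained toy example (not project code):

from typing import Generator, Optional

def toy_parser() -> Generator[str, Optional[int], None]:
    token: Optional[int] = None
    while True:
        # Pause here; resume with whatever the caller sends.
        token = yield f"last token was {token}"
        if token is None:  # sending None ends the loop, as the cleanup code above describes
            return

gen = toy_parser()
print(next(gen))     # prime the generator: prints "last token was None"
print(gen.send(7))   # prints "last token was 7"
try:
    gen.send(None)   # like cleanup(): sending None finishes the generator
except StopIteration:
    pass
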
self._done = True
break

if mask is not None:
assert r.temperature is not None
gen_data = GenData(
tokens=tokens,
mask=mask,
temperature=r.temperature,
assert mask is not None
if token is None:
raise TokenParserException("Expected token, got None")
if not mask[token]:
# Note: we could punt this problem to ll_interpreter.post_process,
# but it's a bit clearer to handle it here
raise InvalidTokenException(
token=token,
valid_tokens=[i for i in range(len(mask)) if mask[i]],
prompt_tokens=tokens
)
# Send caller the mask and response; wait for token
token = yield (gen_data, response)
if token is None:
raise TokenParserException("Expected token, got None")
if not mask[token]:
# Note: we could punt this problem to ll_interpreter.post_process,
# but it's a bit clearer to handle it here
raise InvalidTokenException(token, gen_data.valid_next_tokens, tokens)
else:
gen_data = None
token = yield (gen_data, response)
if token is not None:
raise TokenParserException(f"Expected None, got token {token}")

backtrack, ff_tokens = self.ll_interpreter.post_process(token)
if backtrack:
tokens = tokens[:-backtrack]
tokens = tokens + ff_tokens

def cleanup(self):
# Rather than having our caller send us None at the end, we'll handle that internally
# so we can (1) verify that the generator actually stops and (2) check the stop reason
# and raise if needed
if not self.done():
try:
self._generator.send(None)
except StopIteration:
pass
if not self.done():
raise TokenParserException("Tried to cleanup but parser is not done")
stop_reason = self.ll_interpreter.stop_reason()
if stop_reason not in {"NoExtension", "EndOfSentence"}:
# TODO: extend exception handling
# Will raise if there is some "bad" stop reason (like hit token limit) OR we're NOT stopped.
# TODO: raise specific exceptions for reasons such as MaxTokensTotal
raise TokenParserException(f"Unexpected stop reason: {stop_reason}")

return response


class ByteParserException(Exception):
def __init__(self, *args, **kwargs):
self.current_byte = kwargs.pop("current_byte", None)
@@ -155,6 +181,8 @@ def __init__(
self.pos = 0
self._variables: dict[str, Any] = {}
self._variables_log_probs: dict[str, Any] = {}
# Prime the parser
self._advance(None)
self.consume_bytes(prompt)

def matched(self) -> bool:
@@ -179,14 +207,26 @@ def next_byte_mask(self) -> NDArray[np.uint8]:
mask[t[0]] = 1
return mask

def consume_bytes(self, bts: bytes) -> None:
# Run underlying ll_parser and fast-forward all of our bytes
# until we have a "choice" (generation step) to make
while self.gen_data is None and not self.token_parser.done():
self.gen_data, response = self.token_parser.advance(None)
self._update_capture(response)
self.bytes += response.new_bytes
def _advance(self, token: Optional[int]) -> None:
tokens, mid_process_fut = self.token_parser.advance(token)
mask, ll_response = mid_process_fut.result()
if ll_response.stop:
assert mask is None
self.token_parser.cleanup()
self.gen_data = None
else:
assert mask is not None
assert ll_response.temperature is not None
self.gen_data = GenData(
tokens=tokens,
mask=mask,
temperature=ll_response.temperature,
)
response = ll_response.progress.to_engine_call_response()
self._update_capture(response)
self.bytes += response.new_bytes

def consume_bytes(self, bts: bytes) -> None:
if not bts:
return

@@ -228,9 +268,7 @@ def consume_bytes(self, bts: bytes) -> None:
consumed_bytes=self.bytes[: self.pos],
)
# Byte was good, have ll_parser consume it so we can advance further
self.gen_data, response = self.token_parser.advance(b)
self._update_capture(response)
self.bytes += response.new_bytes
self._advance(b)

# Run consume_bytes to advance ll_parser and consume the next byte
self.consume_bytes(bts)
@@ -241,9 +279,7 @@ def force_done(self):
if self.token_parser.done():
return

self.gen_data, response = self.token_parser.advance(self.tokenizer.eos_token_id)
self._update_capture(response)
self.bytes += response.new_bytes
self._advance(self.tokenizer.eos_token_id)
if not self.token_parser.done() or not self.matched():
raise ByteParserException("Hit end of input before reaching a valid state")

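Putting the new TokenParser interface together, the caller-side protocol looks roughly like this (an illustrative sketch only; pick_next_token is a hypothetical placeholder, and the real driver is Engine.__call__ in guidance/models/_model.py below):

token = None
while True:
    tokens, mid_process_future = parser.advance(token)
    mask, ll_response = mid_process_future.result()  # in the real loop, awaited only after get_logits
    if ll_response.stop:
        assert mask is None   # a stop response is the only case with no mask
        parser.cleanup()      # sends the final None internally and validates stop_reason
        break
    token = pick_next_token(tokens, mask, ll_response.temperature)  # hypothetical sampler
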
4 changes: 2 additions & 2 deletions guidance/models/_mock.py
@@ -80,9 +80,9 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs, force):
# seed the random number generator
self._rand_generator = np.random.default_rng(seed=42)

def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
def sample_with_temperature(self, logits, mask, temperature):
self.called_temperatures.append(temperature)
return super().get_next_token(token_ids, mask, temperature)
return super().sample_with_temperature(logits, mask, temperature)

def get_logits(self, token_ids: list[int]) -> np.ndarray:
"""Pretends to compute the logits for the given token state."""
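Since the Mock hook now lives on sample_with_temperature, here is a rough sketch of what such a sampler can look like, applying the byte mask before sampling. It is only an illustration consistent with the signature in _model.py below (one byte per token id, nonzero meaning allowed), not the library's actual implementation:

import numpy as np
from typing import Optional

def sample_with_temperature(logits: np.ndarray, mask: Optional[bytes], temperature: float) -> int:
    masked = logits.astype(np.float64)
    if mask is not None:
        allowed = np.frombuffer(mask, dtype=np.uint8).astype(bool)
        masked = np.where(allowed, masked, -np.inf)  # rule out tokens the grammar forbids
    if temperature == 0.0:
        return int(np.argmax(masked))  # greedy decoding at temperature zero
    scaled = masked / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(np.random.default_rng().choice(len(probs), p=probs))
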
102 changes: 73 additions & 29 deletions guidance/models/_model.py
@@ -133,42 +133,86 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCallResponse]:
"""
parser = self.start(prompt, grammar, ensure_bos_token)

has_get_logits = True
token = None
while not parser.done():
gen_data, response = parser.advance(token)

if gen_data is not None:
if parser.is_accepting() and self.tokenizer.eos_token_id is not None:
# Whenever we are in an accepting state, we will allow the model to generate whatever it wants
# but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully.
assert gen_data.mask[self.tokenizer.eos_token_id]
token = self.get_next_token(
token_ids=gen_data.tokens,
mask=None,
temperature=gen_data.temperature
)
if not gen_data.mask[token]:
token = self.tokenizer.eos_token_id
else:
token = self.get_next_token(
token_ids=gen_data.tokens,
mask=gen_data.mask,
temperature=gen_data.temperature
)
while True:
tokens, mid_process_fut = parser.advance(token)

# Note that has_pending_stop implies that the response is a stop response,
# but the converse is not true. We can therefore avoid some (but not all)
# unnecessary calls to get_logits on the final iteration.
has_pending_stop = parser.has_pending_stop()

if has_get_logits and not has_pending_stop:
try:
logits = self.get_logits(token_ids=tokens)
except NotImplementedError:
# Permanently fall-back to get_next_token if get_logits is not implemented
has_get_logits = False
logits = None
else:
token = None
logits = None

# Important: don't wait on this future until after getting the logits;
# this allows the mask to be built concurrently with model inference
mask, ll_response = mid_process_fut.result()

engine_response = ll_response.progress.to_engine_call_response()
yield engine_response

if ll_response.stop:
assert mask is None
# May raise an exception if the parser is in a bad state!
parser.cleanup()
# Ensure we break AFTER yielding the final response
break

# If there was a pending stop, we should have broken out of the loop
assert not has_pending_stop

# Help the type checker: assert that everything we need to get the next token is not None
assert mask is not None
assert ll_response.temperature is not None

can_finish_early = parser.is_accepting() and self.tokenizer.eos_token_id is not None

if can_finish_early:
# Type checker needs some help
assert self.tokenizer.eos_token_id is not None
# Should be equivalent to parser.is_accepting()
assert mask[self.tokenizer.eos_token_id]
# Whenever we are in an accepting state, we will allow the model to generate whatever it wants
# but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully.
# Hence, mask must be None
mask_for_sampling = None
else:
mask_for_sampling = mask

if logits is not None:
token = self.sample_with_temperature(
logits=logits,
mask=mask_for_sampling,
temperature=ll_response.temperature,
)
else:
token = self.get_next_token(
tokens,
mask_for_sampling,
ll_response.temperature
)

if can_finish_early and not mask[token]:
# Type checker needs some help
assert self.tokenizer.eos_token_id is not None
token = self.tokenizer.eos_token_id

yield response

def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
"""Base implementation for getting the next token from the model which calls get_logits and sample_with_temperature.
Subclasses may override this method, e.g. if they use external APIs that do not support getting logits directly.
"""
logits = self.get_logits(token_ids)
token = self.sample_with_temperature(logits, mask, temperature)
return token
# Prefer to implement get_logits over get_next_token as it allows for concurrent mask computation
raise NotImplementedError

def get_logits(self, token_ids: list[int]) -> np.ndarray:
# Prefer to implement get_logits over get_next_token as it allows for concurrent mask computation
raise NotImplementedError

def sample_with_temperature(self, logits: np.ndarray, mask: Optional[bytes], temperature: float) -> int:
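The practical takeaway for Engine subclasses, per the comments above: implement get_logits whenever raw logits are available, since that is what allows the mask to be built concurrently with the forward pass, and only override get_next_token for backends (such as remote APIs) that can only return an already-sampled token. A hedged sketch, assuming Engine is importable from guidance.models._model and using hypothetical placeholder helpers:

import numpy as np
from typing import Optional
from guidance.models._model import Engine  # assumed import path for the class in this diff

def run_forward_pass(token_ids: list[int]) -> np.ndarray:
    # Hypothetical placeholder for a local model forward pass.
    return np.zeros(1000)

def call_remote_api(token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
    # Hypothetical placeholder for an external sampling API.
    return 0

class LocalEngine(Engine):
    # Preferred path: returning logits lets the engine overlap mask construction with inference.
    def get_logits(self, token_ids: list[int]) -> np.ndarray:
        return run_forward_pass(token_ids)

class RemoteEngine(Engine):
    # Fallback path: the API samples for us, so the concurrency benefit is lost.
    def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
        return call_remote_api(token_ids, mask, temperature)
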
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
"referencing",
"requests",
"tiktoken>=0.3",
"llguidance>=0.1.7",
"llguidance>=0.3.0",
]

# Our basic list of 'extras'
10 changes: 5 additions & 5 deletions tests/model_integration/test_model.py
@@ -74,18 +74,18 @@ def test_associativity(selected_model: models.Model):
REMOTE_MODELS = [models.AzureGuidance]
for rm in REMOTE_MODELS:
if isinstance(selected_model, rm):
pytest.skip("Method get_next_token not available for remote models")
pytest.skip("Method get_logits not available for remote models")
prompt = "pi = "
grammar = gen("number", regex=r"\d")
engine = selected_model.engine

with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_1:
with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_1:
_ = selected_model + (prompt + grammar)
prompt_tokens_1 = get_next_token_1.call_args.kwargs["token_ids"]
prompt_tokens_1 = get_logits_1.call_args_list[0].kwargs["token_ids"]

with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_2:
with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_2:
_ = (selected_model + prompt) + grammar
prompt_tokens_2 = get_next_token_2.call_args.kwargs["token_ids"]
prompt_tokens_2 = get_logits_2.call_args_list[0].kwargs["token_ids"]

# Main assertion: the prompt tokens should be the same
assert prompt_tokens_1 == prompt_tokens_2