diff --git a/guidance/_parser.py b/guidance/_parser.py
index e6a22671d..910315feb 100644
--- a/guidance/_parser.py
+++ b/guidance/_parser.py
@@ -1,6 +1,7 @@
 import json
 import os
-from typing import Any, Generator, Optional, Tuple, Union
+from typing import Any, Generator, Optional, Union
+from concurrent.futures import ThreadPoolExecutor, Future
 
 import llguidance  # type: ignore[import-untyped]
 import numpy as np
@@ -11,7 +12,6 @@
 from .models._byte_tokenizer import ByteTokenizer
 from .models._tokenizer import Tokenizer
 
-
 class TokenParserException(Exception):
     pass
 
@@ -52,8 +52,10 @@ def __init__(
             serialized_grammar,
             log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
         )
+        self._threadpool = ThreadPoolExecutor(max_workers=1)
         self._generator = self._parse(prompt, ensure_bos_token)
         self._done = False
+        self._has_pending_stop = False
 
     def is_accepting(self) -> bool:
         return self.ll_interpreter.is_accepting()
@@ -63,12 +65,13 @@ def done(self) -> bool:
 
     def advance(
         self, token: Optional[int]
-    ) -> Tuple[Optional[GenData], EngineCallResponse]:
-        try:
-            return self._generator.send(token)
-        except StopIteration as e:
-            self._done = True
-            return None, e.value
+    ) -> tuple[list[int], Future[tuple[Optional[bytes], LLInterpreterResponse]]]:
+        if self.done():
+            raise TokenParserException("Cannot advance on a done parser")
+        return self._generator.send(token)
+
+    def has_pending_stop(self) -> bool:
+        return self._has_pending_stop
 
     def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:
         prompt_tokens = self.ll_interpreter.process_prompt(
@@ -84,55 +87,78 @@ def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:
 
         return self.tokenizer.recode(prompt_tokens)
 
+    def mid_process(self) -> tuple[Optional[bytes], LLInterpreterResponse]:
+        mask, ll_response_string = self.ll_interpreter.mid_process()
+        ll_response = LLInterpreterResponse.model_validate_json(ll_response_string)
+        return mask, ll_response
 
     def _parse(
         self,
         prompt: bytes,
         ensure_bos_token: bool,
-    ) -> Generator[Tuple[Optional[GenData], EngineCallResponse], Optional[int], EngineCallResponse]:
+    ) -> Generator[
+        tuple[
+            list[int],
+            Future[tuple[Optional[bytes], LLInterpreterResponse]],
+        ],
+        Optional[int],
+        None
+    ]:
         tokens = self._process_prompt(prompt=prompt, ensure_bos_token=ensure_bos_token)
 
         while True:
-            mask, resp = self.ll_interpreter.mid_process()
-            r = LLInterpreterResponse.model_validate_json(resp)
-            response = r.progress.to_engine_call_response()
-            if r.stop:
+            # Note: need to call/set has_pending_stop before spinning up the mid_process future
+            # as the two methods cannot be called concurrently
+            self._has_pending_stop = self.ll_interpreter.has_pending_stop()
+            mid_process_future = self._threadpool.submit(self.mid_process)
+            token = yield (tokens, mid_process_future)
+
+            # Upstairs should have already waited on this future
+            mask, ll_response = mid_process_future.result()
+
+            if ll_response.stop:
+                # This is the only case in which the mask is None
+                assert mask is None
+                # If we're done, our caller should NOT send us a token
+                if token is not None:
+                    raise TokenParserException(f"Expected None, got token {token}")
+                self._done = True
                 break
 
-            if mask is not None:
-                assert r.temperature is not None
-                gen_data = GenData(
-                    tokens=tokens,
-                    mask=mask,
-                    temperature=r.temperature,
+            assert mask is not None
+            if token is None:
+                raise TokenParserException("Expected token, got None")
+            if not mask[token]:
+                # Note: we could punt this problem to ll_interpreter.post_process,
+                # but it's a bit clearer to handle it here
+                raise InvalidTokenException(
+                    token=token,
+                    valid_tokens=[i for i in range(len(mask)) if mask[i]],
+                    prompt_tokens=tokens
                 )
-                # Send caller the mask and response; wait for token
-                token = yield (gen_data, response)
-                if token is None:
-                    raise TokenParserException("Expected token, got None")
-                if not mask[token]:
-                    # Note: we could punt this probem to ll_interpreter.post_process,
-                    # but it's a bit clearer to handle it here
-                    raise InvalidTokenException(token, gen_data.valid_next_tokens, tokens)
-            else:
-                gen_data = None
-                token = yield (gen_data, response)
-                if token is not None:
-                    raise TokenParserException(f"Expected None, got token {token}")
 
             backtrack, ff_tokens = self.ll_interpreter.post_process(token)
             if backtrack:
                 tokens = tokens[:-backtrack]
             tokens = tokens + ff_tokens
 
+    def cleanup(self):
+        # Rather than having our caller send us None at the end, we'll handle that internally
+        # so we can (1) verify that the generator actually stops and (2) check the stop reason
+        # and raise if needed
+        if not self.done():
+            try:
+                self._generator.send(None)
+            except StopIteration:
+                pass
+        if not self.done():
+            raise TokenParserException("Tried to cleanup but parser is not done")
         stop_reason = self.ll_interpreter.stop_reason()
         if stop_reason not in {"NoExtension", "EndOfSentence"}:
-            # TODO: extend exception handling
+            # Will raise if there is some "bad" stop reason (like hit token limit) OR we're NOT stopped.
+            # TODO: raise specific exceptions for reasons such as MaxTokensTotal
             raise TokenParserException(f"Unexpected stop reason: {stop_reason}")
-        return response
-
-
 class ByteParserException(Exception):
     def __init__(self, *args, **kwargs):
         self.current_byte = kwargs.pop("current_byte", None)
@@ -155,6 +181,8 @@ def __init__(
         self.pos = 0
         self._variables: dict[str, Any] = {}
         self._variables_log_probs: dict[str, Any] = {}
+        # Prime the parser
+        self._advance(None)
         self.consume_bytes(prompt)
 
     def matched(self) -> bool:
@@ -179,14 +207,26 @@ def next_byte_mask(self) -> NDArray[np.uint8]:
             mask[t[0]] = 1
         return mask
 
-    def consume_bytes(self, bts: bytes) -> None:
-        # Run underlying ll_parser and fast-forward all of our bytes
-        # until we have a "choice" (generation step) to make
-        while self.gen_data is None and not self.token_parser.done():
-            self.gen_data, response = self.token_parser.advance(None)
-            self._update_capture(response)
-            self.bytes += response.new_bytes
+    def _advance(self, token: Optional[int]) -> None:
+        tokens, mid_process_fut = self.token_parser.advance(token)
+        mask, ll_response = mid_process_fut.result()
+        if ll_response.stop:
+            assert mask is None
+            self.token_parser.cleanup()
+            self.gen_data = None
+        else:
+            assert mask is not None
+            assert ll_response.temperature is not None
+            self.gen_data = GenData(
+                tokens=tokens,
+                mask=mask,
+                temperature=ll_response.temperature,
+            )
+        response = ll_response.progress.to_engine_call_response()
+        self._update_capture(response)
+        self.bytes += response.new_bytes
 
+    def consume_bytes(self, bts: bytes) -> None:
         if not bts:
             return
 
@@ -228,9 +268,7 @@ def consume_bytes(self, bts: bytes) -> None:
                     consumed_bytes=self.bytes[: self.pos],
                 )
             # Byte was good, have ll_parser consume it so we can advance further
-            self.gen_data, response = self.token_parser.advance(b)
-            self._update_capture(response)
-            self.bytes += response.new_bytes
+            self._advance(b)
             # Run consume_bytes to advance ll_parser and consume the next byte
             self.consume_bytes(bts)
 
@@ -241,9 +279,7 @@ def consume_bytes(self, bts: bytes) -> None:
     def force_done(self):
        if self.token_parser.done():
            return
-        self.gen_data, response = self.token_parser.advance(self.tokenizer.eos_token_id)
-        self._update_capture(response)
-        self.bytes += response.new_bytes
+        self._advance(self.tokenizer.eos_token_id)
 
         if not self.token_parser.done() or not self.matched():
             raise ByteParserException("Hit end of input before reaching a valid state")
diff --git a/guidance/models/_mock.py b/guidance/models/_mock.py
index 0f1e48b41..a2ae59a3a 100644
--- a/guidance/models/_mock.py
+++ b/guidance/models/_mock.py
@@ -80,9 +80,9 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs, force):
         # seed the random number generator
         self._rand_generator = np.random.default_rng(seed=42)
 
-    def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
+    def sample_with_temperature(self, logits, mask, temperature):
         self.called_temperatures.append(temperature)
-        return super().get_next_token(token_ids, mask, temperature)
+        return super().sample_with_temperature(logits, mask, temperature)
 
     def get_logits(self, token_ids: list[int]) -> np.ndarray:
         """Pretends to compute the logits for the given token state."""
diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index e4e08d0ff..8e02fd4fd 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -133,42 +133,86 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCal
         """
         parser = self.start(prompt, grammar, ensure_bos_token)
 
+        has_get_logits = True
         token = None
-        while not parser.done():
-            gen_data, response = parser.advance(token)
-
-            if gen_data is not None:
-                if parser.is_accepting() and self.tokenizer.eos_token_id is not None:
-                    # Whenever we are in an accepting state, we will allow the model to generate whatever it wants
-                    # but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully.
-                    assert gen_data.mask[self.tokenizer.eos_token_id]
-                    token = self.get_next_token(
-                        token_ids=gen_data.tokens,
-                        mask=None,
-                        temperature=gen_data.temperature
-                    )
-                    if not gen_data.mask[token]:
-                        token = self.tokenizer.eos_token_id
-                else:
-                    token = self.get_next_token(
-                        token_ids=gen_data.tokens,
-                        mask=gen_data.mask,
-                        temperature=gen_data.temperature
-                    )
+        while True:
+            tokens, mid_process_fut = parser.advance(token)
+
+            # Note that has_pending_stop implies that the response is a stop response,
+            # but the converse is not true. We can therefore avoid some (but not all)
+            # unnecessary calls to get_logits on the final iteration.
+            has_pending_stop = parser.has_pending_stop()
+
+            if has_get_logits and not has_pending_stop:
+                try:
+                    logits = self.get_logits(token_ids=tokens)
+                except NotImplementedError:
+                    # Permanently fall back to get_next_token if get_logits is not implemented
+                    has_get_logits = False
+                    logits = None
             else:
-                token = None
+                logits = None
+
+            # Important: don't wait on this future until after getting the logits;
+            # this allows the mask to be built concurrently with model inference
+            mask, ll_response = mid_process_fut.result()
+
+            engine_response = ll_response.progress.to_engine_call_response()
+            yield engine_response
+
+            if ll_response.stop:
+                assert mask is None
+                # May raise an exception if the parser is in a bad state!
+                parser.cleanup()
+                # Ensure we break AFTER yielding the final response
+                break
+
+            # If there was a pending stop, we should have broken out of the loop
+            assert not has_pending_stop
+
+            # Help the type checker: assert that everything we need to get the next token is not None
+            assert mask is not None
+            assert ll_response.temperature is not None
+
+            can_finish_early = parser.is_accepting() and self.tokenizer.eos_token_id is not None
+
+            if can_finish_early:
+                # Type checker needs some help
+                assert self.tokenizer.eos_token_id is not None
+                # Should be equivalent to parser.is_accepting()
+                assert mask[self.tokenizer.eos_token_id]
+                # Whenever we are in an accepting state, we will allow the model to generate whatever it wants
+                # but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully.
+                # Hence, mask must be None
+                mask_for_sampling = None
+            else:
+                mask_for_sampling = mask
+
+            if logits is not None:
+                token = self.sample_with_temperature(
+                    logits=logits,
+                    mask=mask_for_sampling,
+                    temperature=ll_response.temperature,
+                )
+            else:
+                token = self.get_next_token(
+                    tokens,
+                    mask_for_sampling,
+                    ll_response.temperature
+                )
+
+            if can_finish_early and not mask[token]:
+                # Type checker needs some help
+                assert self.tokenizer.eos_token_id is not None
+                token = self.tokenizer.eos_token_id
 
-            yield response
 
     def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
-        """Base implementation for getting the next token from the model which calls get_logits and sample_with_temperature.
-        Subclasses may override this method, e.g. if they use external APIs that do not support getting logits directly.
-        """
-        logits = self.get_logits(token_ids)
-        token = self.sample_with_temperature(logits, mask, temperature)
-        return token
+        # Prefer to implement get_logits over get_next_token as it allows for concurrent mask computation
+        raise NotImplementedError
 
     def get_logits(self, token_ids: list[int]) -> np.ndarray:
+        # Prefer to implement get_logits over get_next_token as it allows for concurrent mask computation
        raise NotImplementedError
 
     def sample_with_temperature(self, logits: np.ndarray, mask: Optional[bytes], temperature: float) -> int:
diff --git a/setup.py b/setup.py
index 45290dcf1..c8b801566 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
     "referencing",
     "requests",
     "tiktoken>=0.3",
-    "llguidance>=0.1.7",
+    "llguidance>=0.3.0",
 ]
 
 # Our basic list of 'extras'
diff --git a/tests/model_integration/test_model.py b/tests/model_integration/test_model.py
index 2b6a2489c..b02b719bc 100644
--- a/tests/model_integration/test_model.py
+++ b/tests/model_integration/test_model.py
@@ -74,18 +74,18 @@ def test_associativity(selected_model: models.Model):
     REMOTE_MODELS = [models.AzureGuidance]
     for rm in REMOTE_MODELS:
         if isinstance(selected_model, rm):
-            pytest.skip("Method get_next_token not available for remote models")
+            pytest.skip("Method get_logits not available for remote models")
     prompt = "pi = "
     grammar = gen("number", regex=r"\d")
     engine = selected_model.engine
 
-    with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_1:
+    with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_1:
         _ = selected_model + (prompt + grammar)
-        prompt_tokens_1 = get_next_token_1.call_args.kwargs["token_ids"]
+        prompt_tokens_1 = get_logits_1.call_args_list[0].kwargs["token_ids"]
 
-    with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_2:
+    with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_2:
         _ = (selected_model + prompt) + grammar
-        prompt_tokens_2 = get_next_token_2.call_args.kwargs["token_ids"]
+        prompt_tokens_2 = get_logits_2.call_args_list[0].kwargs["token_ids"]
 
     # Main assertion: the prompt tokens should be the same
     assert prompt_tokens_1 == prompt_tokens_2
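
Reviewer note (not part of the diff): the sketch below illustrates the backend contract that the new Engine loop in guidance/models/_model.py expects. The class names MyLogitsEngine and MyApiEngine and the helpers self._forward / self._client are hypothetical placeholders. Backends that can expose raw logits should implement get_logits, so the llguidance mask (built by the mid_process future on a worker thread) is computed concurrently with model inference; backends wrapping external sampling APIs can override get_next_token instead, which the loop falls back to once get_logits raises NotImplementedError.

# Illustrative sketch only; assumes the Engine base class shown in the diff above.
from typing import Optional

import numpy as np

from guidance.models._model import Engine


class MyLogitsEngine(Engine):
    # Preferred path: return raw logits; the inherited sample_with_temperature
    # then applies the mask and temperature when picking the next token.
    def get_logits(self, token_ids: list[int]) -> np.ndarray:
        return self._forward(token_ids)  # hypothetical forward pass over the vocabulary


class MyApiEngine(Engine):
    # Fallback path for backends that cannot return logits: leave get_logits
    # unimplemented (the base class raises NotImplementedError) and sample remotely.
    # Mask handling is omitted here for brevity.
    def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
        return self._client.sample_token(token_ids, temperature)  # hypothetical API call

MockEngine in guidance/models/_mock.py follows the first pattern: it keeps its get_logits implementation and now overrides sample_with_temperature (rather than get_next_token) to record the temperatures it was called with.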