Skip to content

Commit

Permalink
Support multiple EOS token ids + fix bugs in JsonSchemaParser (#123)
Browse files Browse the repository at this point in the history
* feat: add support for multiple EOS token IDs

* test: add test cases for leading zeros without decimal points + e notation in JSON

* fix: fix JsonSchemaParser for number type and add many test cases

---------

Co-authored-by: Andrew Wang <andrewwa@nvidia.com>
  • Loading branch information
aw632 and Andrew Wang authored Jul 27, 2024
1 parent 20e0b6d commit 075fc2c
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 10 deletions.
3 changes: 2 additions & 1 deletion lmformatenforcer/integrations/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, token_enforcer: TokenEnforcer, eos_token_id, analyze):
self.eos_token_id = eos_token_id

def _trim(self, input):
return [x for x in input.tolist() if x != self.eos_token_id]
return [x for x in input.tolist() if x not in \
(self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id])]

def __call__(self, step: int, batch_input_ids: List[List[int]], logits: torch.Tensor) -> torch.Tensor:
for idx in range(len(batch_input_ids)):
Expand Down
26 changes: 25 additions & 1 deletion lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,12 +438,16 @@ def __init__(
self.allow_floating_point = allow_floating_point
self.seen_decimal_point = False
self.seen_whitespace_after_digits = False
self.seen_exponent = False
self.seen_digit = False

def _clone(self) -> "NumberParsingState":
    """Return a copy of this parsing state with all progress flags preserved."""
    duplicate = NumberParsingState(self.root, self.allow_floating_point)
    # Copy every piece of mutable parsing progress; keeping the names in one
    # place makes it harder to forget a flag when a new one is added.
    for attr in (
        "parsed_string",
        "seen_decimal_point",
        "seen_whitespace_after_digits",
        "seen_exponent",
        "seen_digit",
    ):
        setattr(duplicate, attr, getattr(self, attr))
    return duplicate

def add_character(self, new_character: str) -> CharacterLevelParser:
Expand All @@ -455,7 +459,17 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
self.seen_whitespace_after_digits = True
return self
if new_character == ".":
if not self.parsed_string or len(self.parsed_string) == 1:
raise LMFormatEnforcerException("Numbers cannot start with a decimal point.")
if self.seen_decimal_point:
raise LMFormatEnforcerException("Numbers cannot contain more than one decimal point.")
self.seen_decimal_point = True
elif new_character in "eE":
if self.seen_exponent or not self.seen_digit:
raise LMFormatEnforcerException("Invalid number format")
self.seen_exponent = True
elif new_character.isdigit():
self.seen_digit = True
return self

def get_allowed_characters(self) -> str:
Expand All @@ -464,13 +478,23 @@ def get_allowed_characters(self) -> str:
allowed_characters = "0123456789"
if not self.parsed_string:
allowed_characters += "-" + WHITESPACE_CHARACTERS
if self.allow_floating_point and not self.seen_decimal_point:
if self.parsed_string and len(self.parsed_string) == 1 and self.parsed_string[0] == "0":
allowed_characters = WHITESPACE_CHARACTERS
if self.parsed_string and len(self.parsed_string) == 2 and self.parsed_string == "-0":
allowed_characters = "." + WHITESPACE_CHARACTERS
if self.parsed_string and self.parsed_string[-1] in "eE":
allowed_characters += "-+"
if self.seen_digit and not self.seen_exponent:
allowed_characters += "eE"
if self.allow_floating_point and not self.seen_decimal_point and self.seen_digit and not self.seen_exponent:
allowed_characters += "."
if self.parsed_string and self.parsed_string[-1].isdigit():
allowed_characters += WHITESPACE_CHARACTERS
return allowed_characters

def can_end(self) -> bool:
    """Return True when the characters seen so far form a complete number."""
    # A dangling exponent marker ("1e") or a bare exponent sign ("1e+")
    # is not a complete number yet.
    if self.seen_exponent and self.parsed_string[-1] in "eE+-":
        return False
    # Nothing parsed yet -> cannot end here.
    if not self.parsed_string:
        return False
    # Complete once the text ends in a digit, or once trailing whitespace
    # was consumed after the digits.
    return self.parsed_string[-1].isdigit() or self.seen_whitespace_after_digits


Expand Down
10 changes: 5 additions & 5 deletions lmformatenforcer/tokenenforcer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
import sys
from typing import Callable, Dict, Hashable, List, Optional, Tuple
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Union
import logging

from .exceptions import LMFormatEnforcerException
Expand All @@ -14,13 +14,13 @@ class TokenEnforcerTokenizerData:
def __init__(self,
regular_tokens: List[Tuple[int, str, bool]],
decoder: Callable[[List[int]], str],
eos_token_id: int):
eos_token_id: Union[int, List[int]]):
"""
Create the tokenizer data that the TokenEnforcer needs. This can be reused for multiple TokenEnforcers if they work with the same tokenizer.
:param regular_tokens: A list of tuples (token_id, token_string, is_new_word_token) for all the regular (not special) tokens in the tokenizer vocabulary.
Note that token_string is expected to include leading / trailing whitespaces if relevant.
:param decoder: A function that decodes a list of token ids into a string.
:param eos_token_id: The token id of the end-of-string token.
:param eos_token_id: The token id(s) of the end-of-string token(s).
"""
self.regular_tokens = regular_tokens
self.tokenizer_tree = TokenizerPrefixTree(regular_tokens)
Expand Down Expand Up @@ -95,7 +95,7 @@ def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.Out
shortcut_key = state.parser.shortcut_key()
self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
if state.parser.can_end():
allowed_tokens.append(self.eos_token_id)
allowed_tokens.extend(self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id])
if not allowed_tokens:
raise ValueError(f"Parser reached state with no allowed tokens")
# root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
Expand All @@ -115,7 +115,7 @@ def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.Out
"Terminating the parser. Please open an issue at \n"
"https://github.com/noamgat/lm-format-enforcer/issues with the prefix and "
"CharacterLevelParser parameters")
state.allowed_tokens = [self.eos_token_id]
state.allowed_tokens = self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id]

def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: TokenizerPrefixTreeNode, allowed_tokens: List[int], shortcut_key: Optional[Hashable]):
allowed_tokens.extend(tree_node.tokens)
Expand Down
6 changes: 3 additions & 3 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
# the parser the most.
target_token_array = _tokenizer.encode(prompt + string)
eos_token_id = _tokenizer.eos_token_id
if eos_token_id is None:
raise ValueError("Tokenizer does not have an EOS token")
if not eos_token_id:
raise ValueError(f"Tokenizer does not have {'an EOS token' if eos_token_id is None else 'EOS tokens'}")

token_enforcer = TokenEnforcer(_tokenizer_data, parser)
# The token enforcer is stateful - it keeps track of the parsing state as tokens arrive on a token by token basis.
Expand All @@ -82,7 +82,7 @@ def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevel
return # Test success
else:
# Reached the end of the sequence, check that ending state matches expected ending state
can_end = eos_token_id in allowed_tokens
can_end = any(token in allowed_tokens for token in (eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]))
if can_end and not expect_success:
raise ValueError("Parser succeeded when it should have failed")
if not can_end and expect_success:
Expand Down
102 changes: 102 additions & 0 deletions tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,105 @@ def test_top_level_object_inheritance():
valid_object = '{"parent": {"child": "test"}}'
_test_json_schema_parsing_with_string(valid_object, schema, True)


class NumberSchema(BaseModel):
    """Pydantic model with a single float field, used by the number-format tests below."""
    # NOTE(review): "type" is not a standard Field() argument — presumably passed
    # through as extra JSON-schema metadata; confirm against the pydantic version in use.
    value: float = Field(..., type="number")


# JSON schema dict shared by every parametrized number test in this module.
schema = NumberSchema.model_json_schema()

@pytest.mark.parametrize("test_input", [
    '{"value": 0}',
    '{"value": 1}',
    '{"value": 10}',
    '{"value": 0.1}',
    '{"value": 1.01}',
    '{"value": -1}',
    '{"value": -0.1}',
    '{"value": 1e5}',
    '{"value": 1.5e-5}',
    '{"value": 1.5e5}',
    '{"value": 1.5e+5}',
    '{"value": -1.5e5}',
    '{"value": -1.5e-5}',
    '{"value": -1.5e+5}',
    '{"value": 0.0}',
    '{"value": -0.0}',
    '{"value": 1.0}',
    '{"value": -1.0}',
    '{"value": 1.5e0}',
    '{"value": -1.5e0}',
    '{"value": 9007199254740991}',
    '{"value": -9007199254740991}',
    '{"value": 1e-323}',
    '{"value": 1.7976931348623157e+308}',
    '{"value": 5e-324}',
    '{"value": 2.2250738585072014e-308}',
])
def test_valid_number_formats(test_input):
    """Numbers that are valid per the JSON grammar (RFC 8259) must be accepted."""
    _test_json_schema_parsing_with_string(test_input, schema, True)


@pytest.mark.parametrize("test_input", [
    '{"value": 01}',
    '{"value": 00.1}',
    '{"value": 01.01}',
    '{"value": -01}',
    '{"value": -00.1}',
    '{"value": 01e5}',
    '{"value": 00}',
    '{"value": 00.0}',
    '{"value": 00.0e5}',
    '{"value": -00.0e5}',
    '{"value": 0123}',
    '{"value": -0123}',
    '{"value": 01.23e45}',
])
def test_invalid_number_formats_with_leading_zeros(test_input):
    """JSON forbids leading zeros ("01", "-00.1", ...); the parser must reject them."""
    _test_json_schema_parsing_with_string(test_input, schema, False)


@pytest.mark.parametrize("test_input, expected_success", [
    ('{"value": .1}', False),
    ('{"value": -.1}', False),
    ('{"value": 1.}', False),
    ('{"value": +1}', False),
    ('{"value": 1e}', False),
    ('{"value": 1e+}', False),
    ('{"value": .}', False),
    ('{"value": -.}', False),
    ('{"value": e5}', False),
    ('{"value": .e5}', False),
    ('{"value": -.e5}', False),
    ('{"value": 1.5e}', False),
    ('{"value": 1.5e+}', False),
    ('{"value": -1.5e}', False),
    ('{"value": -1.5e+}', False),
    ('{"value": 1.5e-}', False),
    ('{"value": -1.5e-}', False),
    ('{"value": 1e-}', False),
    ('{"value": -1e-}', False),
    ('{"value": 1e+1e2}', False),
    ('{"value": 1e1.5}', False),
    ('{"value": 1e-1.5}', False),
    ('{"value": 1e1a}', False),
    ('{"value": 1e-1a}', False),
    ('{"value": 0x123}', False),
    ('{"value": 0b1010}', False),
    ('{"value": 0o123}', False),
    ('{"value": Infinity}', False),
    ('{"value": -Infinity}', False),
    ('{"value": NaN}', False),
    ('{"value": 1,000}', False),
    ('{"value": 1_000}', False),
    ('{"value": 1.2.3}', False),
    ('{"value": 1e2e3}', False),
    ('{"value": 1e+2e-3}', False),
    ('{"value": --1}', False),
    ('{"value": ++1}', False),
    ('{"value": +-1}', False),
    # Values beyond JavaScript's MAX_SAFE_INTEGER are still valid JSON numbers,
    # so the final two cases are expected to parse successfully.
    ('{"value": 9007199254740992}', True),
    ('{"value": -9007199254740992}', True),
])
def test_number_edge_cases(test_input, expected_success):
    """Malformed numbers (bare signs, dangling exponents, hex/binary literals,
    Infinity/NaN, separators) must be rejected; oversized-but-valid integers accepted."""
    _test_json_schema_parsing_with_string(test_input, schema, expected_success)

0 comments on commit 075fc2c

Please sign in to comment.