From 0f1deeeb7b8df4a96483b74d849164ad5a8ca3a5 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 17 Sep 2024 12:37:42 -0400 Subject: [PATCH] add safe_subset argument to json_schema.to_regex, implement safe get_int_pattern / get_str_pattern --- outlines/fsm/json_schema.py | 332 ++++++++++++++++++++++++++++++---- tests/fsm/test_json_schema.py | 193 ++++++++++++++++++-- 2 files changed, 466 insertions(+), 59 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 98d2de59c..72fa25f50 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,8 +1,10 @@ import inspect +import itertools import json +import math import re import warnings -from typing import Callable, Optional, Tuple, Type, Union +from typing import Callable, List, Optional, Tuple, Type, Union from jsonschema.protocols import Validator from pydantic import BaseModel, create_model @@ -18,14 +20,21 @@ NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" NULL = r"null" -WHITESPACE = r"[ ]?" +WHITESPACE = r"[\n\t ]*" +SAFE_WHITESPACE = r"[ ]?" +SAFE_INT_MAX = int(1e19) +SAFE_INT_MIN = int(-1e19) +SAFE_STR_MAX_LEN = 256 + + +# TODO: Deprecate? This isn't used anywhere internally type_to_regex = { "string": STRING, - "integer": INTEGER, "number": NUMBER, "boolean": BOOLEAN, "null": NULL, + "integer": INTEGER, } DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"' @@ -41,7 +50,9 @@ } -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): +def build_regex_from_schema( + schema: str, whitespace_pattern: Optional[str] = None, safe_subset: bool = True +): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. 
@@ -60,6 +71,13 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + safe_subset + Use a subset of json schema which performs better with language models. + If you want to allow the model to generate any json structure, set to False. + Changes the following: + - If whitespace_pattern is None, sets whitespace pattern to SAFE_WHITESPACE (r"[ ]?") + - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19] + - If unconstrained string is used, constrain it to max of 256 characters Returns ------- @@ -83,7 +101,7 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non resolver = registry.resolver() content = schema.contents - return to_regex(resolver, content, whitespace_pattern) + return to_regex(resolver, content, whitespace_pattern, safe_subset) def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: @@ -173,7 +191,10 @@ def validate_quantifiers( def to_regex( - resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None + resolver: Resolver, + instance: dict, + whitespace_pattern: Optional[str] = None, + safe_subset: bool = True, ): """Translate a JSON Schema instance into a regex that validates the schema. @@ -196,11 +217,18 @@ def to_regex( whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + safe_subset + Use a subset of json schema which performs better with language models. + If you want to allow the model to generate any json structure, set to False. 
+ Changes the following: + - If whitespace_pattern is None, sets whitespace pattern to SAFE_WHITESPACE (r"[ ]?") + - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19] + - If unconstrained string is used, constrain it to max of 256 characters """ # set whitespace pattern if whitespace_pattern is None: - whitespace_pattern = WHITESPACE + whitespace_pattern = SAFE_WHITESPACE if safe_subset else WHITESPACE if instance == {}: # JSON Schema Spec: Empty object means unconstrained, any json type is legal @@ -213,7 +241,9 @@ def to_regex( {"type": "array"}, {"type": "object"}, ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] + regexes = [ + to_regex(resolver, t, whitespace_pattern, safe_subset) for t in types + ] regexes = [rf"({r})" for r in regexes] return rf"{'|'.join(regexes)}" @@ -231,7 +261,7 @@ def to_regex( last_required_pos = max([i for i, value in enumerate(is_required) if value]) for i, (name, value) in enumerate(properties.items()): subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) + subregex += to_regex(resolver, value, whitespace_pattern, safe_subset) if i < last_required_pos: subregex = f"{subregex}{whitespace_pattern}," elif i > last_required_pos: @@ -245,7 +275,7 @@ def to_regex( property_subregexes = [] for i, (name, value) in enumerate(properties.items()): subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) + subregex += to_regex(resolver, value, whitespace_pattern, safe_subset) property_subregexes.append(subregex) possible_patterns = [] for i in range(len(property_subregexes)): @@ -266,7 +296,8 @@ def to_regex( # given subschemas. 
elif "allOf" in instance: subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["allOf"] ] subregexes_str = [f"{subregex}" for subregex in subregexes] return rf"({''.join(subregexes_str)})" @@ -275,7 +306,8 @@ def to_regex( # any (one or more) of the given subschemas. elif "anyOf" in instance: subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["anyOf"] ] return rf"({'|'.join(subregexes)})" @@ -283,7 +315,8 @@ def to_regex( # one of the given subschemas. elif "oneOf" in instance: subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["oneOf"] ] xor_patterns = [f"(?:{subregex})" for subregex in subregexes] @@ -293,7 +326,8 @@ def to_regex( # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx elif "prefixItems" in instance: element_patterns = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["prefixItems"] ] comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" tuple_inner = comma_split_pattern.join(element_patterns) @@ -321,7 +355,7 @@ def to_regex( elif "$ref" in instance: path = f"{instance['$ref']}" instance = resolver.lookup(path).contents - return to_regex(resolver, instance, whitespace_pattern) + return to_regex(resolver, instance, whitespace_pattern, safe_subset) # The type keyword may either be a string or an array: # - If it's a string, it is the name of one of the basic types. 
@@ -332,16 +366,11 @@ def to_regex( instance_type = instance["type"] if instance_type == "string": if "maxLength" in instance or "minLength" in instance: - max_items = instance.get("maxLength", "") - min_items = instance.get("minLength", "") - try: - if int(max_items) < int(min_items): - raise ValueError( - "maxLength must be greater than or equal to minLength" - ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) - except ValueError: - pass - return f'"{STRING_INNER}{{{min_items},{max_items}}}"' + return get_str_pattern( + min_length=instance.get("minLength"), + max_length=instance.get("maxLength"), + safe_subset=safe_subset, + ) elif "pattern" in instance: pattern = instance["pattern"] if pattern[0] == "^" and pattern[-1] == "$": @@ -363,9 +392,11 @@ def to_regex( f"Format {format} is not supported by Outlines" ) else: - return type_to_regex["string"] + return get_str_pattern(safe_subset=safe_subset) elif instance_type == "number": + # TODO: implement actual json schema spec parameters: "maximum" and "minimum", + # should be easy through extending get_int_pattern bounds = { "minDigitsInteger", "maxDigitsInteger", @@ -402,15 +433,24 @@ def to_regex( else "+" ) return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" - return type_to_regex["number"] + return NUMBER elif instance_type == "integer": - if "minDigits" in instance or "maxDigits" in instance: - min_digits, max_digits = validate_quantifiers( - instance.get("minDigits"), instance.get("maxDigits"), start_offset=1 + # TODO: Remove errors eventually - these keys aren't part of json schema spec + if "maxDigits" in instance: + raise ValueError( + "'maxDigits' is not supported. Please use 'maximum' instead." + ) + if "minDigits" in instance: + raise ValueError( + "'minDigits' is not supported. Please use 'minimum' instead." 
) - return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})" - return type_to_regex["integer"] + + return get_int_pattern( + minimum=instance.get("minimum"), + maximum=instance.get("maximum"), + safe_subset=safe_subset, + ) elif instance_type == "array": num_repeats = _get_num_items_pattern( @@ -422,7 +462,9 @@ def to_regex( allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" if "items" in instance: - items_regex = to_regex(resolver, instance["items"], whitespace_pattern) + items_regex = to_regex( + resolver, instance["items"], whitespace_pattern, safe_subset + ) return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" else: # Here we need to make the choice to exclude generating list of objects @@ -441,7 +483,8 @@ def to_regex( legal_types.append({"type": "array", "depth": depth - 1}) regexes = [ - to_regex(resolver, t, whitespace_pattern) for t in legal_types + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in legal_types ] return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" @@ -481,11 +524,12 @@ def to_regex( legal_types.append({"type": "array", "depth": depth - 1}) additional_properties = {"anyOf": legal_types} + key_pattern = get_str_pattern(safe_subset=safe_subset) value_pattern = to_regex( - resolver, additional_properties, whitespace_pattern + resolver, additional_properties, whitespace_pattern, safe_subset ) key_value_pattern = ( - f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" + f"{key_pattern}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" ) key_value_successor_pattern = ( f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" @@ -501,17 +545,17 @@ def to_regex( ) elif instance_type == "boolean": - return type_to_regex["boolean"] + return BOOLEAN elif instance_type == "null": - return type_to_regex["null"] + 
return NULL elif isinstance(instance_type, list): # Here we need to make the choice to exclude generating an object # if the specification of the object is not give, even though a JSON # object that contains an object here would be valid under the specification. regexes = [ - to_regex(resolver, {"type": t}, whitespace_pattern) + to_regex(resolver, {"type": t}, whitespace_pattern, safe_subset) for t in instance_type if t != "object" ] @@ -550,3 +594,213 @@ def get_schema_from_signature(fn: Callable) -> str: model = create_model(fn_name, **arguments) return model.model_json_schema() + + +def get_subranges(minimum: int, maximum: int) -> List[Tuple]: + """ + Convert a range into a list of subranges which can fit into a pattern + + E.g. minimum=123, maximum=456 cannot easily be made into a regex pattern + therefore, (123, 456) is converted to + [(123, 129), (130, 199), (200, 399), (400, 449), (450, 456)] + which can be converted in get_subrange_pattern() to + ["12[3-9]", "(1[3-9][0-9]{1}", "[2-3][0-9]{2}", "4[0-4][0-9]{1}", "45[0-6]"] + """ + min_str = str(minimum).zfill(len(str(maximum))) + max_str = str(maximum) + + # if only the last digit varies, its a valid subrange + if min_str[:-1] == max_str[:-1]: + return [(minimum, maximum)] + + # calculate the shared prefix between minimum and maximum and left-truncate it for now + num_shared_prefix = len( + list(itertools.takewhile(lambda x: x[0] == x[1], zip(min_str, max_str))) + ) + shared_min = min_str[num_shared_prefix:] + shared_max = max_str[num_shared_prefix:] + prefix = min_str[:num_shared_prefix] + + # determine how many trailing digits back are valid [0-9] + # set first digit which doesn't qualify as the flex + # then combine: {prefix}{flex}[0-9]{count} + num_truncate = len(shared_min) - len(shared_min.rstrip("0")) + 1 + child_max = int(prefix + shared_min[:-num_truncate] + "9" * num_truncate) + if child_max > maximum: + child_max = int(prefix + shared_max[0] + "0" * len(shared_max[1:])) - 1 + + if child_max == 
maximum: + return [(minimum, child_max)] + return [(minimum, child_max)] + get_subranges(child_max + 1, maximum) + + +def get_subrange_pattern(minimum: int, maximum: int) -> str: + """ + Generates a regex pattern for a subrange where digits can be represented using character classes. + + This function creates a regex pattern for a given integer subrange where the digits can be + represented using character classes and quantifiers. It assumes that the range can be represented + by varying specific digits while others remain constant or within a simple range. + + For example: + - (200, 399) -> '([2-3][0-9]{2})' + - (310, 319) -> '(31[0-9])' + - (100, 189) -> '(1[0-8][0-9])' + + The function should only be called with ranges that can be represented in this way. + It does not handle ranges where digits do not align for simple character classes. + + Args: + minimum (int): The lower bound of the integer subrange. + maximum (int): The upper bound of the integer subrange. + + Returns: + str: A regex pattern string that matches all integers in the subrange. 
+ """ + + max_str = str(maximum) + min_str = str(minimum).zfill(len(max_str)) + + last_range_zero = len(min_str) - re.search(r"[^0]|$", min_str[::-1]).start() # type: ignore + last_range_nine = len(max_str) - re.search(r"[^9]|$", max_str[::-1]).start() # type: ignore + if last_range_zero is None or last_range_nine is None: + raise RuntimeError(f"invalid string range: {minimum} to {maximum}") + full_range_start = max(last_range_zero, last_range_nine) + + shared_prefix = min_str[: full_range_start - 1] + range_digit_min, range_digit_max = ( + min_str[full_range_start - 1], + max_str[full_range_start - 1], + ) + + pattern = rf"{shared_prefix}[{range_digit_min}-{range_digit_max}]" + + num_0_9_chars = len(max_str) - full_range_start + if num_0_9_chars: + pattern += rf"[0-9]{{{num_0_9_chars}}}" + + return rf"({pattern})" + + +def get_positive_int_range_pattern(minimum: int, maximum: int) -> str: + """ + Generates a regex pattern for positive integers within a specified range. + + This function creates a regex pattern that matches positive integers from `minimum` to `maximum`. + It handles ranges with finite and infinite upper bounds, and can include zero explicitly if + needed. + + The function splits the range into subranges suitable for pattern generation, and combines + the patterns for each subrange using alternation (the '|' operator). + + Args: + minimum (int or inf): The lower bound of the integer range (must be >= 0). + maximum (int or inf): The upper bound of the integer range (must be >= 0 or infinity). + + Returns: + str: A regex pattern string that matches all positive integers in the range. + """ + assert minimum >= 0 + assert maximum >= 0 + + # Handle the case where zero needs to be included explicitly. + if minimum == 0 and maximum == 0: + return "(0)" # Return a pattern that matches zero. + elif minimum == 0: + minimum = 1 + explicit_zero = True # Flag to include OR Zero (`|0`) in the final pattern. 
+ else: + explicit_zero = False + + if maximum == float("inf"): + # Handle infinite upper bound. + # Create and OR two patterns: (minimum, lower_maximum - 1) | (lower_maximum, infinity) + lower_maximum = 10 ** math.ceil(math.log10(minimum + 1)) - 1 + lower_maximum_pattern = "|".join( + [ + get_subrange_pattern(sub_min, sub_max) + for sub_min, sub_max in get_subranges(minimum, lower_maximum) + ] + ) + lower_max_to_infinity_pattern = rf"[\d]{{{len(str(lower_maximum))+1},}}" + pattern = f"({lower_max_to_infinity_pattern}|{lower_maximum_pattern})" + else: + pattern = "|".join( + [ + get_subrange_pattern(sub_min, sub_max) + for sub_min, sub_max in get_subranges(minimum, maximum) + ] + ) + + if explicit_zero: + pattern = rf"(0|({pattern}))" + + return pattern + + +def get_int_pattern(minimum=None, maximum=None, safe_subset: bool = False) -> str: + """ + This function generates a regex pattern that matches integers from `minimum` to `maximum`, + inclusive. It handles negative ranges, positive ranges, zero, and ranges that span both negative + and positive numbers. + + If no bounds are specified, it defaults to matching all integers. The `safe_subset` parameter + can be used to limit the range to safe integer values (e.g., to avoid excessively large numbers). + + Args: + minimum (int, (+/-)inf, optional): The lower bound of the integer range. Defaults to negative infinity. + maximum (int, (+/-)inf, optional): The upper bound of the integer range. Defaults to positive infinity. + safe_subset (bool, optional): If True, uses SAFE_INT_MIN and SAFE_INT_MAX as default bounds. + + Returns: + str: A regex pattern string that matches all integers in the specified range. 
+ """ + # handle safe subset of range + if minimum is None: + minimum = SAFE_INT_MIN if safe_subset else -float("inf") + if maximum is None: + maximum = SAFE_INT_MAX if safe_subset else float("inf") + + if (minimum, maximum) == (-float("inf"), float("inf")): + return INTEGER + + assert minimum <= maximum + + if minimum == maximum == 0: + pattern = "0" + elif minimum >= 0 and maximum >= 0: + pattern = get_positive_int_range_pattern(minimum, maximum) + elif minimum < 0 and maximum <= 0: + # entirely negative range: prefix with `-` and calculate abs of range + abs_pattern = get_positive_int_range_pattern(max(abs(maximum), 1), abs(minimum)) + pattern = rf"-({abs_pattern})" + if maximum == 0: + pattern = rf"0|({pattern})" + else: # minimum < 0 and maximum > 0: + # positive component of range | negative component + minimum_pattern = get_positive_int_range_pattern(1, abs(minimum)) + maximum_pattern = get_positive_int_range_pattern(0, maximum) + pattern = rf"(-({minimum_pattern}))|({maximum_pattern})" + + return rf"({pattern})" + + +def get_str_pattern( + min_length: Optional[int] = None, + max_length: Optional[int] = None, + safe_subset: bool = False, +) -> str: + if min_length is None and max_length is None and not safe_subset: + return STRING + elif min_length and max_length and int(max_length or 0) < int(min_length or 0): + raise ValueError("maxLength must be greater than or equal to minLength") + elif (min_length or 0) < 0 or (max_length or 0) < 0: + raise ValueError("minLength and maxLength must be greater than or equal to 0") + + range_begin = str(min_length) if min_length else "" + if max_length is None: + range_end = str(SAFE_STR_MAX_LEN) if safe_subset else "" + else: + range_end = str(max_length) + + return f'"{STRING_INNER}{{{range_begin},{range_end}}}"' diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 7565ff642..e81881ace 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,5 +1,7 @@ import json 
+import random import re +import string as pystring from typing import List, Literal, Union import interegular @@ -13,16 +15,21 @@ INTEGER, NULL, NUMBER, + SAFE_WHITESPACE, STRING, STRING_INNER, TIME, UUID, - WHITESPACE, build_regex_from_schema, + get_int_pattern, get_schema_from_signature, + get_str_pattern, to_regex, ) +SAFE_INT = get_int_pattern(safe_subset=True) +SAFE_STR = get_str_pattern(safe_subset=True) + def test_function_basic(): def test_function(foo: str, bar: List[int]): @@ -71,7 +78,7 @@ class User(BaseModel): ) def test_match_integer(pattern, does_match): step = {"title": "Foo", "type": "integer"} - regex = to_regex(None, step) + regex = to_regex(None, step, safe_subset=False) assert regex == INTEGER value = pattern["integer"] @@ -83,6 +90,100 @@ def test_match_integer(pattern, does_match): assert match is None +@pytest.mark.parametrize( + "minimum,maximum", + [ + (0, 0), + (-1, 0), + (0, 1), + (-15, 0), + (0, 15), + (-1, 1), + (-15, 15), + (-1234, 56), + (-56, 1234), + (-9, 9), + (-10, 10), + (-9, 10), + (-10, 9), + (123, 199), + (123, 456), + (5600, 5678), + (550, 560), + (-12345, 3423), + (50, 10000), + ], +) +def test_int_range_pattern(minimum, maximum): + pattern = get_int_pattern(minimum, maximum) + fsm = interegular.parse_pattern(pattern).to_fsm().reduce() + pattern_numbers = {"".join(s) for s in fsm.strings()} + range_numbers = set(map(str, range(minimum, maximum + 1))) + assert pattern_numbers == range_numbers + + # logarithmic space complexity + assert len(fsm.states) <= (len(str(minimum)) + len(str(maximum))) * 2 + + +def test_int_range_unconstrained(): + # test unconstrained + pattern = get_int_pattern(float("-inf"), float("inf")) + fsm = interegular.parse_pattern(pattern).to_fsm().reduce() + assert get_int_pattern(None, None) == pattern + assert fsm.accepts("0") + assert fsm.accepts("-1") + assert fsm.accepts("1") + assert fsm.accepts("-98427983498234893274983274892") + assert fsm.accepts("2994399439493294329432984932") + + # caveat: 
json_schema.INTEGER allows -0, we can remove this functionality safely at any time + # assert not fsm.accepts("-0") + + assert not fsm.accepts("1.1") + assert not fsm.accepts("1.0") + assert not fsm.accepts("1.0") + assert not fsm.accepts("one") + + assert len(fsm.states) < 5 + + +def test_int_range_min_zero(): + # test min zero + pattern = get_int_pattern(0, float("inf")) + fsm = interegular.parse_pattern(pattern).to_fsm() + assert fsm.accepts("0") + assert not fsm.accepts("-1") + assert fsm.accepts("1") + assert not fsm.accepts("-98427983498234893274983274892") + assert fsm.accepts("2994399439493294329432984932") + + +def test_int_range_max_zero(): + # test max zero + pattern = get_int_pattern(-float("inf"), 0) + fsm = interegular.parse_pattern(pattern).to_fsm() + assert fsm.accepts("0") + assert fsm.accepts("-1") + assert not fsm.accepts("1") + assert fsm.accepts("-98427983498234893274983274892") + assert not fsm.accepts("2994399439493294329432984932") + + +def test_int_range_max_minus_32(): + # test max of -32 + pattern = get_int_pattern(-float("inf"), -32) + fsm = interegular.parse_pattern(pattern).to_fsm() + assert not fsm.accepts("0") + assert not fsm.accepts("-1") + assert not fsm.accepts("1") + assert not fsm.accepts("32") + assert fsm.accepts("-32") + assert fsm.accepts("-33") + assert fsm.accepts("-39482929438") + assert fsm.accepts("-98427983498234893274983274892") + assert not fsm.accepts("2994399439493294329432984932") + + @pytest.mark.parametrize( "pattern,does_match", [ @@ -110,6 +211,56 @@ def test_match_number(pattern, does_match): assert match is None +@pytest.mark.parametrize( + "min_len,max_len,safe_subset,expected_max,errors", + [ + # if no max and no safe mode, any length allowed + (None, None, False, None, False), + # max_len is None, use safe max_len of 256 + (None, None, True, 256, False), + (0, None, True, 256, False), + # if max_len is specified, it overrides safe_subset rules + (None, 500, True, 500, False), + (0, 500, True, 500, False), 
+ # if min_len specification has no effect + (3, 500, True, 500, False), + (30, 500, True, 500, False), + (300, 500, True, 500, False), + # illegal + (-1, None, True, None, True), + (30, 20, True, None, True), + ], +) +def test_get_str_pattern(min_len, max_len, safe_subset, expected_max, errors): + if errors: + with pytest.raises(ValueError): + get_str_pattern(min_len, max_len, safe_subset) + return + + pattern = get_str_pattern(min_len, max_len, safe_subset) + + # verify str len in (min_len, max_len) + def str_of_len(str_len): + s = "".join(random.choices(pystring.ascii_letters + pystring.digits, k=str_len)) + return f'"{s}"' + + # verify min_len held + min_len = min_len or 0 + assert re.match(pattern, str_of_len(min_len)) + if min_len != 0: + assert not re.match(pattern, str_of_len(min_len - 1)) + + # verify expected_max accurate + if expected_max is not None: + assert re.match(pattern, str_of_len(expected_max)) + assert re.match(pattern, str_of_len(max(expected_max - 1, min_len))) + assert not re.match(pattern, str_of_len(expected_max + 1)) + else: + assert re.match(pattern, str_of_len(max(min_len, 100))) + assert re.match(pattern, str_of_len(max(min_len, 1000))) + assert re.match(pattern, str_of_len(max(min_len, 100000))) + + @pytest.mark.parametrize( "schema,regex,examples", [ @@ -267,11 +418,11 @@ def test_match_number(pattern, does_match): "title": "Foo", "type": "object", "properties": { - "count": {"title": "Count", "type": "integer", "minDigits": 3} + "count": {"title": "Count", "type": "integer", "minimum": 100} }, "required": ["count"], }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}', + '\\{[ ]?"count"[ ]?:[ ]?(([\\d]{4,}|([1-9][0-9]{2})))[ ]?\\}', [('{ "count": 10 }', False), ('{ "count": 100 }', True)], ), # integer with maximum digits @@ -280,14 +431,14 @@ def test_match_number(pattern, does_match): "title": "Foo", "type": "object", "properties": { - "count": {"title": "Count", "type": "integer", "maxDigits": 3} + "count": {"title": 
"Count", "type": "integer", "maximum": 999} }, "required": ["count"], }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}', + '\\{[ ]?"count"[ ]?:[ ]?((-(([\\d]{2,}|([1-9]))))|((0|(([1-9])|([1-9][0-9]{1})|([1-9][0-9]{2})))))[ ]?\\}', [('{ "count": 100 }', True), ('{ "count": 1000 }', False)], ), - # integer with minimum and maximum digits + # integer with minimum and maximum ( { "title": "Foo", @@ -296,13 +447,13 @@ def test_match_number(pattern, does_match): "count": { "title": "Count", "type": "integer", - "minDigits": 3, - "maxDigits": 5, + "minimum": 50, + "maximum": 50000, } }, "required": ["count"], }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}', + '\\{[ ]?"count"[ ]?:[ ]?(([5-9][0-9]{1})|([1-9][0-9]{2})|([1-9][0-9]{3})|([1-4][0-9]{4})|(5000[0-0]))[ ]?\\}', [ ('{ "count": 10 }', False), ('{ "count": 100 }', True), @@ -420,7 +571,7 @@ def test_match_number(pattern, does_match): # array ( {"title": "Foo", "type": "array", "items": {"type": "number"}}, - rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}(({NUMBER})(,{SAFE_WHITESPACE}({NUMBER})){{0,}})?{SAFE_WHITESPACE}\]", [("[1e+9,1.3]", True), ("[]", True), ("[1", False)], ), # array with a set length of 1 @@ -432,7 +583,7 @@ def test_match_number(pattern, does_match): "minItems": 1, "maxItems": 1, }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{0,0}}){WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}(({INTEGER})(,{SAFE_WHITESPACE}({INTEGER})){{0,0}}){SAFE_WHITESPACE}\]", [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)], ), # array with a set length greather than 1 @@ -444,7 +595,7 @@ def test_match_number(pattern, does_match): "minItems": 3, "maxItems": 3, }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}(({INTEGER})(,{SAFE_WHITESPACE}({INTEGER})){{2,2}}){SAFE_WHITESPACE}\]", [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], ), # 
array with length 0 @@ -456,7 +607,7 @@ def test_match_number(pattern, does_match): "minItems": 0, "maxItems": 0, }, - rf"\[{WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}\]", [("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)], ), # object @@ -473,7 +624,7 @@ def test_match_number(pattern, does_match): }, "required": ["test_dict"], }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + rf"""\{{{SAFE_WHITESPACE}"test_dict"{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{STRING}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{STRING}){{0,}})?{SAFE_WHITESPACE}\}}{SAFE_WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), ("""{ "test_dict":{"foo":"bar" }}""", True), @@ -499,7 +650,7 @@ def test_match_number(pattern, does_match): }, "required": ["test_dict"], }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + 
rf"""\{{{SAFE_WHITESPACE}"test_dict"{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}){{0,}})?{SAFE_WHITESPACE}\}}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}){{0,}})?{SAFE_WHITESPACE}\}}){{0,}})?{SAFE_WHITESPACE}\}}{SAFE_WHITESPACE}\}}""", [ ( """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", @@ -559,7 +710,7 @@ def test_match_number(pattern, does_match): "title": "Foo", "prefixItems": [{"type": "string"}, {"type": "integer"}], }, - rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE},{SAFE_WHITESPACE}{INTEGER}{SAFE_WHITESPACE}\]", [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], ), # Nested schema @@ -751,7 +902,9 @@ def test_match_number(pattern, does_match): def test_match(schema, regex, examples): interegular.parse_pattern(regex) schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) + test_regex = build_regex_from_schema( + schema, whitespace_pattern=SAFE_WHITESPACE, safe_subset=False + ) assert test_regex == regex for string, does_match in examples: @@ -1000,10 +1153,10 @@ class MockModel(BaseModel): # assert any ws pattern can be used if whitespace_pattern == "abc": - build_regex_from_schema(schema, whitespace_pattern) + build_regex_from_schema(schema, whitespace_pattern, safe_subset=False) return - pattern = build_regex_from_schema(schema, whitespace_pattern) + pattern = build_regex_from_schema(schema, whitespace_pattern, safe_subset=False) mock_result_mult_ws = ( """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}"""