Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

json: support integer minimum, maximum, exclusiveMinimum, exclusiveMaximum #7797

Merged
merged 30 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
057bbdc
json: support minimum for positive integer values
ochafik Apr 30, 2024
d69ccb0
json: fix min 0
ochafik Apr 30, 2024
c37c484
json: min + max integer constraints
ochafik May 1, 2024
af63f4f
json: handle negative min / max integer bounds
ochafik May 1, 2024
a381deb
json: fix missing paren min/max bug
ochafik May 1, 2024
f8db478
json: proper paren fix
ochafik May 1, 2024
5a86c6f
json: integration test for schemas
ochafik May 18, 2024
431edb8
json: fix bounds tests
ochafik May 18, 2024
b6b6a6c
Update json-schema-to-grammar.cpp
ochafik May 19, 2024
a786c03
Merge remote-tracking branch 'origin/master' into json-bounds2
ochafik Jun 8, 2024
931b543
json: fix negative max
ochafik Jun 8, 2024
4c1c293
json: fix negative min (w/ more than 1 digit)
ochafik Jun 8, 2024
ac2a8f8
Update test-grammar-integration.cpp
ochafik Jun 8, 2024
3549702
json: nit: move string rules together
ochafik Jun 8, 2024
e933680
json: port min/max integer support to Python & JS
ochafik Jun 8, 2024
a0f1904
nit: move + rename _build_min_max_int
ochafik Jun 8, 2024
dcc27d1
fix min in [1, 9]
ochafik Jun 9, 2024
d1f6791
Update test-grammar-integration.cpp
ochafik Jun 9, 2024
cad377d
add C++11-compatible replacement for std::string_view
ochafik Jun 9, 2024
d6483a9
add min/max constrained int field to pydantic json schema example
ochafik Jun 10, 2024
f03e9b9
Merge remote-tracking branch 'origin/master' into json-bounds2
ochafik Jun 12, 2024
6fa7364
Merge remote-tracking branch 'origin/master' into json-bounds2
ochafik Jun 22, 2024
948e55e
fix merge
ochafik Jun 22, 2024
670d5a6
json: add integration tests for min/max bounds
ochafik Jun 22, 2024
9fb8a75
Merge remote-tracking branch 'origin/master' into json-bounds2
ochafik Jun 23, 2024
d7d957d
Merge remote-tracking branch 'origin/master' into json-bounds2
ochafik Jun 24, 2024
3a80d1e
reshuffle/merge min/max integ test cases
ochafik Jun 24, 2024
09a9b75
nits / cleanups
ochafik Jun 24, 2024
48f417d
Merge remote-tracking branch 'origin/master' into json-bounds2
ochafik Jun 25, 2024
36bf003
defensive code against string out of bounds (apparently different beh…
ochafik Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 245 additions & 1 deletion common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,233 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}

/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
class string_view {
const std::string & _str;
const size_t _start;
const size_t _end;
public:
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}

size_t size() const {
return _end - _start;
}

size_t length() const {
return size();
}

operator std::string() const {
return str();
}

std::string str() const {
return _str.substr(_start, _end - _start);
}

string_view substr(size_t pos, size_t len = std::string::npos) const {
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
}

char operator[](size_t pos) const {
auto index = _start + pos;
if (index >= _end) {
throw std::out_of_range("string_view index out of range");
}
return _str[_start + pos];
}

bool operator==(const string_view & other) const {
std::string this_str = *this;
std::string other_str = other;
return this_str == other_str;
}
};

static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max();

auto digit_range = [&](char from, char to) {
out << "[";
if (from == to) {
out << from;
} else {
out << from << "-" << to;
}
out << "]";
};
auto more_digits = [&](int min_digits, int max_digits) {
out << "[0-9]";
if (min_digits == max_digits && min_digits == 1) {
return;
}
out << "{";
out << min_digits;
if (max_digits != min_digits) {
out << ",";
if (max_digits != std::numeric_limits<int>::max()) {
out << max_digits;
}
}
out << "}";
};
std::function<void(const string_view &, const string_view &)> uniform_range =
[&](const string_view & from, const string_view & to) {
size_t i = 0;
while (i < from.length() && i < to.length() && from[i] == to[i]) {
i++;
}
if (i > 0) {
out << "\"" << from.substr(0, i).str() << "\"";
}
if (i < from.length() && i < to.length()) {
if (i > 0) {
out << " ";
}
auto sub_len = from.length() - i - 1;
if (sub_len > 0) {
auto from_sub = from.substr(i + 1);
auto to_sub = to.substr(i + 1);
auto sub_zeros = repeat("0", sub_len);
auto sub_nines = repeat("9", sub_len);

auto to_reached = false;
out << "(";
if (from_sub == sub_zeros) {
digit_range(from[i], to[i] - 1);
out << " ";
more_digits(sub_len, sub_len);
} else {
out << "[" << from[i] << "] ";
out << "(";
uniform_range(from_sub, sub_nines);
out << ")";
if (from[i] < to[i] - 1) {
out << " | ";
if (to_sub == sub_nines) {
digit_range(from[i] + 1, to[i]);
to_reached = true;
} else {
digit_range(from[i] + 1, to[i] - 1);
}
out << " ";
more_digits(sub_len, sub_len);
}
}
if (!to_reached) {
out << " | ";
digit_range(to[i], to[i]);
out << " ";
uniform_range(sub_zeros, to_sub);
}
out << ")";
} else {
out << "[" << from[i] << "-" << to[i] << "]";
}
}
};

if (has_min && has_max) {
if (min_value < 0 && max_value < 0) {
out << "\"-\" (";
_build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
out << ")";
return;
}

if (min_value < 0) {
out << "\"-\" (";
_build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
out << ") | ";
min_value = 0;
}

auto min_s = std::to_string(min_value);
auto max_s = std::to_string(max_value);
auto min_digits = min_s.length();
auto max_digits = max_s.length();

for (auto digits = min_digits; digits < max_digits; digits++) {
uniform_range(min_s, repeat("9", digits));
min_s = "1" + repeat("0", digits);
out << " | ";
}
uniform_range(min_s, max_s);
return;
}

auto less_decimals = std::max(decimals_left - 1, 1);

if (has_min) {
if (min_value < 0) {
out << "\"-\" (";
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
out << ") | [0] | [1-9] ";
more_digits(0, decimals_left - 1);
} else if (min_value == 0) {
if (top_level) {
out << "[0] | [1-9] ";
more_digits(0, less_decimals);
} else {
more_digits(1, decimals_left);
}
} else if (min_value <= 9) {
char c = '0' + min_value;
auto range_start = top_level ? '1' : '0';
if (c > range_start) {
digit_range(range_start, c - 1);
out << " ";
more_digits(1, less_decimals);
out << " | ";
}
digit_range(c, '9');
out << " ";
more_digits(0, less_decimals);
} else {
auto min_s = std::to_string(min_value);
auto len = min_s.length();
auto c = min_s[0];

if (c > '1') {
digit_range(top_level ? '1' : '0', c - 1);
out << " ";
more_digits(len, less_decimals);
out << " | ";
}
digit_range(c, c);
out << " (";
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
out << ")";
if (c < '9') {
out << " | ";
digit_range(c + 1, '9');
out << " ";
more_digits(len - 1, less_decimals);
}
}
return;
}

if (has_max) {
if (max_value >= 0) {
if (top_level) {
out << "\"-\" [1-9] ";
more_digits(0, less_decimals);
out << " | ";
}
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
} else {
out << "\"-\" (";
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
out << ")";
}
return;
}

throw std::runtime_error("At least one of min_value or max_value must be set");
}

const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";

struct BuiltinRule {
Expand Down Expand Up @@ -160,7 +387,6 @@ static std::string format_literal(const std::string & literal) {
return "\"" + escaped + "\"";
}


class SchemaConverter {
private:
std::function<json(const std::string &)> _fetch_json;
Expand Down Expand Up @@ -686,6 +912,24 @@ class SchemaConverter {
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
int min_value = std::numeric_limits<int>::min();
int max_value = std::numeric_limits<int>::max();
if (schema.contains("minimum")) {
min_value = schema["minimum"].get<int>();
} else if (schema.contains("exclusiveMinimum")) {
min_value = schema["exclusiveMinimum"].get<int>() + 1;
}
if (schema.contains("maximum")) {
max_value = schema["maximum"].get<int>();
} else if (schema.contains("exclusiveMaximum")) {
max_value = schema["exclusiveMaximum"].get<int>() - 1;
}
std::stringstream out;
out << "(";
_build_min_max_int(min_value, max_value, out);
out << ") space";
return _add_rule(rule_name, out.str());
} else if (schema.empty() || schema_type == "object") {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
} else {
Expand Down
1 change: 1 addition & 0 deletions examples/json-schema-pydantic-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class QAPair(BaseModel):
question: str
concise_answer: str
justification: str
stars: Annotated[int, Field(ge=1, le=5)]

class PyramidalSummary(BaseModel):
title: str
Expand Down
Loading
Loading