fix(bigquery): Make JSONPathTokenizer more lenient for new standards (#4447)

* fix(bigquery): Make JSONPathTokenizer more lenient for new standards

* Mutate attr instead of Generator flags

* PR Feedback 1

* Switch to non-Tokenizer solution

* Add comment to parse_var_text

---------

Co-authored-by: George Sittas <giwrgos.sittas@gmail.com>
VaggelisD and georgesittas authored Nov 27, 2024
1 parent 954d8fd commit 73afd0f
Showing 4 changed files with 85 additions and 14 deletions.
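
In practical terms, the standards-based functions JSON_QUERY, JSON_VALUE and JSON_QUERY_ARRAY now escape non-identifier JSON path keys with double quotes, while the legacy JSON_EXTRACT* functions keep the bracketed single-quote form. A minimal sketch of the round trip the new tests below exercise (assuming a sqlglot build that includes this commit):

import sqlglot

# Standards-based functions: double-quote escaping for keys with spaces
print(sqlglot.transpile("SELECT JSON_QUERY(doc, '$. a b c .d')", read="bigquery", write="bigquery")[0])
# expected: SELECT JSON_QUERY(doc, '$." a b c ".d')

# Legacy functions: bracketed single-quote escaping
print(sqlglot.transpile("SELECT JSON_EXTRACT(doc, '$. a b c .d')", read="bigquery", write="bigquery")[0])
# expected: SELECT JSON_EXTRACT(doc, '$[\' a b c \'].d')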
25 changes: 25 additions & 0 deletions sqlglot/dialects/bigquery.py
@@ -40,6 +40,11 @@
logger = logging.getLogger("sqlglot")


JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar, exp.JSONExtractArray]

DQUOTES_ESCAPING_JSON_FUNCTIONS = ("JSON_QUERY", "JSON_VALUE", "JSON_QUERY_ARRAY")


def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
@@ -324,6 +329,23 @@ def _build_contains_substring(args: t.List) -> exp.Contains | exp.Anonymous:
return exp.Contains(this=this, expression=expr)


def _json_extract_sql(self: BigQuery.Generator, expression: JSON_EXTRACT_TYPE) -> str:
name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
upper = name.upper()

dquote_escaping = upper in DQUOTES_ESCAPING_JSON_FUNCTIONS

if dquote_escaping:
self._quote_json_path_key_using_brackets = False

sql = rename_func(upper)(self, expression)

if dquote_escaping:
self._quote_json_path_key_using_brackets = True

return sql


class BigQuery(Dialect):
WEEK_OFFSET = -1
UNNEST_COLUMN_ONLY = True
@@ -869,6 +891,9 @@ class Generator(generator.Generator):
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.Int64: rename_func("INT64"),
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractArray: _json_extract_sql,
exp.JSONExtractScalar: _json_extract_sql,
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Levenshtein: _levenshtein_sql,
exp.Max: max_or_greatest,
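The new _json_extract_sql helper flips the generator's _quote_json_path_key_using_brackets attribute only for the duration of the rename_func call and then restores the default. Purely as an illustration of that save/flip/restore shape, a hypothetical context manager (not part of sqlglot; the commit does the equivalent inline):

from contextlib import contextmanager

@contextmanager
def dquote_json_path_keys(gen):
    # Temporarily switch from bracket/single-quote escaping to double-quote
    # escaping, restoring the default on the way out.
    gen._quote_json_path_key_using_brackets = False
    try:
        yield gen
    finally:
        gen._quote_json_path_key_using_brackets = True
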
11 changes: 9 additions & 2 deletions sqlglot/generator.py
@@ -658,6 +658,7 @@ class Generator(metaclass=_Generator):
"_next_name",
"_identifier_start",
"_identifier_end",
"_quote_json_path_key_using_brackets",
)

def __init__(
@@ -706,6 +707,8 @@ def __init__(
self._identifier_start = self.dialect.IDENTIFIER_START
self._identifier_end = self.dialect.IDENTIFIER_END

self._quote_json_path_key_using_brackets = True

def generate(self, expression: exp.Expression, copy: bool = True) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
@@ -2871,7 +2874,7 @@ def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str:
if isinstance(expression, int):
return str(expression)

if self.JSON_PATH_SINGLE_QUOTE_ESCAPE:
if self._quote_json_path_key_using_brackets and self.JSON_PATH_SINGLE_QUOTE_ESCAPE:
escaped = expression.replace("'", "\\'")
escaped = f"\\'{expression}\\'"
else:
@@ -4072,7 +4075,11 @@ def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
return f".{this}"

this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
return (
f"[{this}]"
if self._quote_json_path_key_using_brackets and self.JSON_PATH_BRACKETED_KEY_SUPPORTED
else f".{this}"
)

def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
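Taken together, the generator changes mean a key that is not a safe identifier is rendered one of two ways, depending on the new attribute. A standalone sketch of that branching (not sqlglot's API; it assumes a dialect such as BigQuery where both JSON_PATH_SINGLE_QUOTE_ESCAPE and JSON_PATH_BRACKETED_KEY_SUPPORTED are set):

def render_key(key: str, quote_json_path_key_using_brackets: bool = True) -> str:
    # Mirrors json_path_part + _jsonpathkey_sql for a non-identifier key.
    if quote_json_path_key_using_brackets:
        escaped = key.replace("'", "\\'")
        return f"[\\'{escaped}\\']"  # legacy form, e.g. $[\' a b c \'].d
    escaped = key.replace('"', '\\"')
    return f'."{escaped}"'  # double-quoted form, e.g. $." a b c ".d

print(render_key(" a b c "))         # [\' a b c \']
print(render_key(" a b c ", False))  # ." a b c "
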
32 changes: 29 additions & 3 deletions sqlglot/jsonpath.py
@@ -146,6 +146,28 @@ def _parse_bracket() -> exp.JSONPathPart:

return node

def _parse_var_text() -> str:
"""
Consumes & returns the text for a var. In BigQuery it's valid to have a key with spaces
in it, e.g. JSON_QUERY(..., '$. a b c ') should produce a single JSONPathKey(' a b c ').
This is done by merging "consecutive" vars until a key separator is found (dot, colon, etc.)
or the path string is exhausted.
"""
prev_index = i - 2

while _match(TokenType.VAR):
pass

start = 0 if prev_index < 0 else tokens[prev_index].end + 1

if i >= len(tokens):
# This key is the last token for the path, so its text is the remaining path
text = path[start:]
else:
text = path[start : tokens[i].start]

return text

# We canonicalize the JSON path AST so that it always starts with a
# "root" element, so paths like "field" will be generated as "$.field"
_match(TokenType.DOLLAR)
@@ -155,8 +177,10 @@ def _parse_bracket() -> exp.JSONPathPart:
if _match(TokenType.DOT) or _match(TokenType.COLON):
recursive = _prev().text == ".."

if _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
value: t.Optional[str | exp.JSONPathWildcard] = _prev().text
if _match(TokenType.VAR):
value: t.Optional[str | exp.JSONPathWildcard] = _parse_var_text()
elif _match(TokenType.IDENTIFIER):
value = _prev().text
elif _match(TokenType.STAR):
value = exp.JSONPathWildcard()
else:
@@ -170,7 +194,9 @@ def _parse_bracket() -> exp.JSONPathPart:
raise ParseError(_error("Expected key name or * after DOT"))
elif _match(TokenType.L_BRACKET):
expressions.append(_parse_bracket())
elif _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
elif _match(TokenType.VAR):
expressions.append(exp.JSONPathKey(this=_parse_var_text()))
elif _match(TokenType.IDENTIFIER):
expressions.append(exp.JSONPathKey(this=_prev().text))
elif _match(TokenType.STAR):
expressions.append(exp.JSONPathWildcard())
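The net effect on the path parser is that runs of unquoted, whitespace-separated VAR tokens collapse into a single key whose text is sliced out of the original path string. A quick sketch (assuming the module-level jsonpath.parse entry point; exact output shape may vary by version):

from sqlglot import exp, jsonpath

path = jsonpath.parse("$. a b c .d")
keys = [part.this for part in path.expressions if isinstance(part, exp.JSONPathKey)]
print(keys)  # expected: [' a b c ', 'd']
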
31 changes: 22 additions & 9 deletions tests/dialects/test_bigquery.py
@@ -1574,14 +1574,6 @@ def test_bigquery(self):
"snowflake": "IFF((y) <> 0, (x) / (y), NULL)",
},
)
self.validate_all(
"""SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""",
write={
"bigquery": """SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""",
"duckdb": """SELECT '{"class": {"students": []}}' -> '$.class'""",
"snowflake": """SELECT GET_PATH(PARSE_JSON('{"class": {"students": []}}'), 'class')""",
},
)
self.validate_all(
"""SELECT JSON_VALUE_ARRAY('{"arr": [1, "a"]}', '$.arr')""",
write={
@@ -2139,7 +2131,16 @@ def test_null_ordering(self):
},
)

def test_json_extract_scalar(self):
def test_json_extract(self):
self.validate_all(
"""SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""",
write={
"bigquery": """SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""",
"duckdb": """SELECT '{"class": {"students": []}}' -> '$.class'""",
"snowflake": """SELECT GET_PATH(PARSE_JSON('{"class": {"students": []}}'), 'class')""",
},
)

for func in ("JSON_EXTRACT_SCALAR", "JSON_VALUE"):
with self.subTest(f"Testing BigQuery's {func}"):
self.validate_all(
@@ -2164,6 +2165,18 @@ def test_json_extract_scalar(self):
self.parse_one(sql).sql("bigquery", normalize_functions="upper"), sql
)

# Test double quote escaping
for func in ("JSON_VALUE", "JSON_QUERY", "JSON_QUERY_ARRAY"):
self.validate_identity(
f"{func}(doc, '$. a b c .d')", f"""{func}(doc, '$." a b c ".d')"""
)

# Test single quote & bracket escaping
for func in ("JSON_EXTRACT", "JSON_EXTRACT_SCALAR", "JSON_EXTRACT_ARRAY"):
self.validate_identity(
f"{func}(doc, '$. a b c .d')", f"""{func}(doc, '$[\\' a b c \\'].d')"""
)

def test_json_extract_array(self):
for func in ("JSON_QUERY_ARRAY", "JSON_EXTRACT_ARRAY"):
with self.subTest(f"Testing BigQuery's {func}"):
