Skip to content

Commit

Permalink
PR Feedback 1
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD committed Oct 3, 2024
1 parent 39d12e5 commit 13f6889
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 41 deletions.
62 changes: 31 additions & 31 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.generator import Generator
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import unqualify_columns
Expand Down Expand Up @@ -167,35 +166,6 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"


def jsonextract_sql(self: Generator, expression: exp.JSONExtract) -> str:
is_json_extract = self.dialect.settings.get("variant_extract_is_json_extract", True)

# Generate JSON_EXTRACT unless the user has configured that a Snowflake / Databricks
# VARIANT extract (e.g. col:x.y) should map to dot notation (i.e ROW access) in Presto/Trino
if not expression.args.get("variant_extract") or is_json_extract:
return self.func(
"JSON_EXTRACT", expression.this, expression.expression, *expression.expressions
)

this = self.sql(expression, "this")

# Convert the JSONPath extraction `JSON_EXTRACT(col, '$.x.y) to a ROW access col.x.y
segments = []
for path_key in expression.expression.expressions[1:]:
if not isinstance(path_key, exp.JSONPathKey):
# Cannot transpile subscripts, wildcards etc to dot notation
self.unsupported(f"Cannot transpile JSONPath segment '{path_key}' to ROW access")
continue
key = path_key.this
if not exp.SAFE_IDENTIFIER_RE.match(key):
key = f'"{key}"'
segments.append(f".{key}")

expr = "".join(segments)

return f"{this}{expr}"


def _to_int(self: Presto.Generator, expression: exp.Expression) -> exp.Expression:
if not expression.type:
from sqlglot.optimizer.annotate_types import annotate_types
Expand Down Expand Up @@ -436,7 +406,7 @@ class Generator(generator.Generator):
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.JSONExtract: jsonextract_sql,
exp.JSONExtract: lambda self, e: self.jsonextract_sql(e),
exp.Last: _first_last_sql,
exp.LastValue: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
Expand Down Expand Up @@ -694,3 +664,33 @@ def delete_sql(self, expression: exp.Delete) -> str:
expression = t.cast(exp.Delete, expression.transform(unqualify_columns))

return super().delete_sql(expression)

def jsonextract_sql(self, expression: exp.JSONExtract) -> str:
is_json_extract = self.dialect.settings.get("variant_extract_is_json_extract", True)

# Generate JSON_EXTRACT unless the user has configured that a Snowflake / Databricks
# VARIANT extract (e.g. col:x.y) should map to dot notation (i.e ROW access) in Presto/Trino
if not expression.args.get("variant_extract") or is_json_extract:
return self.func(
"JSON_EXTRACT", expression.this, expression.expression, *expression.expressions
)

this = self.sql(expression, "this")

# Convert the JSONPath extraction `JSON_EXTRACT(col, '$.x.y) to a ROW access col.x.y
segments = []
for path_key in expression.expression.expressions[1:]:
if not isinstance(path_key, exp.JSONPathKey):
# Cannot transpile subscripts, wildcards etc to dot notation
self.unsupported(
f"Cannot transpile JSONPath segment '{path_key}' to ROW access"
)
continue
key = path_key.this
if not exp.SAFE_IDENTIFIER_RE.match(key):
key = f'"{key}"'
segments.append(f".{key}")

expr = "".join(segments)

return f"{this}{expr}"
16 changes: 7 additions & 9 deletions sqlglot/dialects/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlglot import exp, parser
from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql, timestrtotime_sql
from sqlglot.dialects.presto import Presto, jsonextract_sql
from sqlglot.dialects.presto import Presto
from sqlglot.tokens import TokenType


Expand Down Expand Up @@ -37,7 +37,7 @@ def _parse_json_query(self):
option = self._parse_var_from_options(self.JSON_QUERY_OPTIONS, raise_unmatched=False)

return self.expression(
exp.JSONExtract, this=this, expression=expr, expressions=[option], json_query=True
exp.JSONExtract, this=this, expression=expr, json_query=True, option=option
)

class Generator(Presto.Generator):
Expand All @@ -48,7 +48,7 @@ class Generator(Presto.Generator):
exp.Merge: merge_without_target_sql,
exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True),
exp.Trim: trim_sql,
exp.JSONExtract: lambda self, e: self._jsonextract_sql(e),
exp.JSONExtract: lambda self, e: self.jsonextract_sql(e),
}

SUPPORTED_JSON_PATH_PARTS = {
Expand All @@ -57,17 +57,15 @@ class Generator(Presto.Generator):
exp.JSONPathSubscript,
}

def _jsonextract_sql(self, expression: exp.JSONExtract) -> str:
def jsonextract_sql(self, expression: exp.JSONExtract) -> str:
if not expression.args.get("json_query"):
return jsonextract_sql(self, expression)
return super().jsonextract_sql(expression)

this = self.sql(expression, "this")
json_path = self.sql(expression, "expression")

option = self.expressions(expression, flat=True)
option = self.sql(expression, "option")
option = f" {option}" if option else ""

return self.func("JSON_QUERY", this, json_path + option)
return self.func("JSON_QUERY", expression.this, json_path + option)

class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5739,6 +5739,7 @@ class JSONExtract(Binary, Func):
"expressions": False,
"variant_extract": False,
"json_query": False,
"option": False,
}
_sql_names = ["JSON_EXTRACT"]
is_var_len_args = True
Expand Down
3 changes: 2 additions & 1 deletion tests/dialects/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ class TestTrino(Validator):
def test_trino(self):
self.validate_identity("JSON_EXTRACT(content, json_path)")
self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT UNCONDITIONAL WRAPPER)")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)")

def test_trim(self):
self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
Expand Down

0 comments on commit 13f6889

Please sign in to comment.