Skip to content

Commit

Permalink
feat(clickhouse): support generate TimestampTrunc, Variance, Stddev (#…
Browse files Browse the repository at this point in the history
…3489)

* feat(clickhouse): support generate TimestampTrunc, Variance, Stddev

* add test cases

* refactor timestamptrunc_sql

* add test cases for TIMESTAMP_TRUNC
  • Loading branch information
longxiaofei authored May 17, 2024
1 parent 6dbf5dd commit 89c1d3a
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 7 deletions.
4 changes: 4 additions & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
build_json_extract_path,
rename_func,
var_map_sql,
timestamptrunc_sql,
)
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
Expand Down Expand Up @@ -761,6 +762,9 @@ class Generator(generator.Generator):
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.UnixToTime: _unix_to_time_sql,
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
exp.Variance: rename_func("varSamp"),
exp.Stddev: rename_func("stddevSamp"),
}

PROPERTIES_LOCATION = {
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Generator(Spark.Generator):
),
exp.DatetimeDiff: _timestamp_diff,
exp.TimestampDiff: _timestamp_diff,
exp.DatetimeTrunc: timestamptrunc_sql,
exp.DatetimeTrunc: timestamptrunc_sql(),
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[
Expand Down
10 changes: 8 additions & 2 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,14 @@ def func(self: Generator, expression: exp.Expression) -> str:
return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
args = [unit_to_str(expression), expression.this]
if zone:
args.append(expression.args.get("zone"))
return self.func("DATE_TRUNC", *args)

return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class Generator(generator.Generator):
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: self.func(
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ class Generator(generator.Generator):
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ class Generator(generator.Generator):
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: self.func(
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ class Generator(generator.Generator):
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
"TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e)
Expand Down
13 changes: 13 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,19 @@ def test_time(self):
},
)

self.validate_all(
"TIMESTAMP_TRUNC(x, DAY, 'UTC')",
write={
"": "TIMESTAMP_TRUNC(x, DAY, 'UTC')",
"duckdb": "DATE_TRUNC('DAY', x)",
"presto": "DATE_TRUNC('DAY', x)",
"postgres": "DATE_TRUNC('DAY', x)",
"snowflake": "DATE_TRUNC('DAY', x)",
"databricks": "DATE_TRUNC('DAY', x)",
"clickhouse": "DATE_TRUNC('DAY', x, 'UTC')",
},
)

for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
f"{unit}(x)",
Expand Down
22 changes: 22 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,28 @@ def test_duckdb(self):
)
self.validate_identity("COPY lineitem (l_orderkey) TO 'orderkey.tbl' WITH (DELIMITER '|')")

self.validate_all(
"VARIANCE(a)",
write={
"duckdb": "VARIANCE(a)",
"clickhouse": "varSamp(a)",
},
)
self.validate_all(
"STDDEV(a)",
write={
"duckdb": "STDDEV(a)",
"clickhouse": "stddevSamp(a)",
},
)
self.validate_all(
"DATE_TRUNC('DAY', x)",
write={
"duckdb": "DATE_TRUNC('DAY', x)",
"clickhouse": "DATE_TRUNC('DAY', x)",
},
)

def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
Expand Down

0 comments on commit 89c1d3a

Please sign in to comment.