Skip to content

Commit

Permalink
chore: switch to ruff (tobymao#2912)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao authored Feb 3, 2024
1 parent f3bdcb0 commit f9fdf7b
Show file tree
Hide file tree
Showing 45 changed files with 307 additions and 206 deletions.
37 changes: 18 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
repos:
- repo: local
hooks:
- id: autoflake
name: autoflake
entry: autoflake -i -r
language: system
types: [ python ]
- id: ruff
name: ruff
description: "Run 'ruff' for extremely fast Python linting"
entry: ruff check
--force-exclude --fix
--ignore E721
--ignore E741
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: []
files: ^(sqlglot/|tests/|setup.py)
- id: isort
args: [--combine-as]
name: isort
entry: isort
language: system
types: [ python ]
files: ^(sqlglot/|tests/|setup.py)
- id: ruff-format
name: ruff-format
description: "Run 'ruff format' for extremely fast Python formatting"
entry: ruff format
--force-exclude
--line-length 100
language: python
types_or: [python, pyi]
require_serial: true
- id: black
name: black
entry: black --line-length 100
language: system
types: [ python ]
require_serial: true
files: ^(sqlglot/|tests/|setup.py)
- id: mypy
name: mypy
entry: mypy sqlglot tests
Expand Down
29 changes: 15 additions & 14 deletions benchmarks/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

# moz_sql_parser 3.10 compatibility
collections.Iterable = collections.abc.Iterable
import gc
import timeit

import numpy as np

#import sqlfluff
#import moz_sql_parser
#import sqloxide
#import sqlparse
# import sqlfluff
# import moz_sql_parser
# import sqloxide
# import sqlparse
import sqltree

import sqlglot
Expand Down Expand Up @@ -170,7 +169,7 @@ def sqlglotrs_parse(sql):


def sqltree_parse(sql):
sqltree.api.sqltree(sql.replace('"', '`').replace("''", '"'))
sqltree.api.sqltree(sql.replace('"', "`").replace("''", '"'))


def sqlparse_parse(sql):
Expand Down Expand Up @@ -206,11 +205,11 @@ def diff(row, column):
libs = [
"sqlglot",
"sqlglotrs",
#"sqlfluff",
# "sqlfluff",
"sqltree",
#"sqlparse",
#"moz_sql_parser",
#"sqloxide",
# "sqlparse",
# "moz_sql_parser",
# "sqloxide",
]
table = []

Expand All @@ -231,10 +230,12 @@ def diff(row, column):
lines.append(border(str("-" * width) for width in widths.values()))

for i, row in enumerate(table):
lines.append(border(
(str(row[column])[0:7] + diff(row, column)).rjust(width)[0 : width]
for column, width in widths.items()
))
lines.append(
border(
(str(row[column])[0:7] + diff(row, column)).rjust(width)[0:width]
for column, width in widths.items()
)
)

for line in lines:
print(line)
2 changes: 1 addition & 1 deletion pdoc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pdoc.__main__ import cli, parser

# Need this import or else import_module doesn't work
import sqlglot
import sqlglot # noqa


def mocked_import(*args, **kwargs):
Expand Down
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ def sqlglotrs_version():
python_requires=">=3.7",
extras_require={
"dev": [
"autoflake",
"black",
"duckdb>=0.6",
"isort",
"mypy>=0.990",
"mypy",
"pandas",
"pyspark",
"python-dateutil",
"pdoc",
"pre-commit",
"ruff",
"types-python-dateutil",
"typing_extensions",
"maturin>=1.4,<2.0",
Expand Down
7 changes: 5 additions & 2 deletions sqlglot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa: F401
"""
.. include:: ../README.md
Expand Down Expand Up @@ -87,11 +88,13 @@ def parse(


@t.overload
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
...


@t.overload
def parse_one(sql: str, **opts) -> Expression: ...
def parse_one(sql: str, **opts) -> Expression:
...


def parse_one(
Expand Down
6 changes: 4 additions & 2 deletions sqlglot/dataframe/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,12 @@ def _create_cte_from_expression(
return cte, name

@t.overload
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
...

@t.overload
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
...

def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
Expand Down
3 changes: 2 additions & 1 deletion sqlglot/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# ruff: noqa: F401
"""
## Dialects
While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult
to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible
SQL transpilation framework.
SQL transpilation framework.
The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible.
Expand Down
18 changes: 11 additions & 7 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,12 @@ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
return table

@t.overload
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
...

@t.overload
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
...

def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
Expand Down Expand Up @@ -555,7 +557,8 @@ class Generator(generator.Generator):
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateDiff: lambda self,
e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: date_add_interval_sql("DATE", "SUB"),
Expand Down Expand Up @@ -598,12 +601,13 @@ class Generator(generator.Generator):
]
),
exp.SHA2: lambda self, e: self.func(
f"SHA256" if e.text("length") == "256" else "SHA512", e.this
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.StabilityProperty: lambda self, e: (
f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToDate: lambda self,
e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
Expand All @@ -614,7 +618,7 @@ class Generator(generator.Generator):
exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToTime: rename_func("TIME"),
Expand Down
3 changes: 2 additions & 1 deletion sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ class Generator(generator.Generator):
exp.Rand: rename_func("randCanonical"),
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.StrPosition: lambda self,
e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}
Expand Down
7 changes: 3 additions & 4 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def normalize_identifier(self, expression: E) -> E:
"""
if (
isinstance(expression, exp.Identifier)
and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
and (
not expression.quoted
or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
Expand Down Expand Up @@ -1020,9 +1020,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")

normalize = lambda identifier: (
self.dialect.normalize_identifier(identifier).name if identifier else None
)
def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
return self.dialect.normalize_identifier(identifier).name if identifier else None

targets = {normalize(expression.this.this)}

Expand Down
9 changes: 6 additions & 3 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ class Generator(MySQL.Generator):
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToUnix: lambda self,
e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
exp.ToChar: lambda self,
e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TsOrDsAdd: lambda self,
e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func(
Expand Down
21 changes: 14 additions & 7 deletions sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,14 @@ class Generator(generator.Generator):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.DateToDi: lambda self,
e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self,
e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self,
e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self,
e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
Expand All @@ -141,16 +145,19 @@ class Generator(generator.Generator):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self,
e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}

def normalize_func(self, name: str) -> str:
Expand Down
18 changes: 12 additions & 6 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,11 @@ class Generator(generator.Generator):
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.DateToDi: lambda self,
e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.DiToDate: lambda self,
e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
Expand Down Expand Up @@ -408,7 +410,8 @@ class Generator(generator.Generator):
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StrToUnix: lambda self,
e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
Expand All @@ -418,17 +421,20 @@ class Generator(generator.Generator):
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self,
e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
f"'{e.args.get('unit') or 'DAY'}'",
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToStr: lambda self,
e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
exp.VariancePop: rename_func("VAR_POP"),
Expand Down
Loading

0 comments on commit f9fdf7b

Please sign in to comment.