From a1fb957b32c82a2a30cbc90d22e2f55e3d56ce4e Mon Sep 17 00:00:00 2001 From: tobymao Date: Tue, 17 Jan 2023 14:51:45 -0800 Subject: [PATCH] fix str_position order --- sqlglot/dialects/dialect.py | 2 +- sqlglot/dialects/hive.py | 4 ++-- sqlglot/dialects/mysql.py | 4 ++-- sqlglot/dialects/snowflake.py | 2 +- sqlglot/dialects/tsql.py | 6 +++++- sqlglot/expressions.py | 2 +- sqlglot/parser.py | 6 +++++- tests/dialects/test_dialect.py | 2 +- 8 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1c840da1b3..eb1fc3f6bf 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -368,7 +368,7 @@ def locate_to_strposition(args): ) -def strposition_to_local_sql(self, expression): +def strposition_to_locate_sql(self, expression): args = self.format_args( expression.args.get("substr"), expression.this, expression.args.get("position") ) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ead13b16e3..ddfd1e8ed0 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -13,7 +13,7 @@ no_safe_divide_sql, no_trycast_sql, rename_func, - strposition_to_local_sql, + strposition_to_locate_sql, struct_extract_sql, timestrtotime_sql, var_map_sql, @@ -297,7 +297,7 @@ class Generator(generator.Generator): exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", - exp.StrPosition: strposition_to_local_sql, + exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date, exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 0fd79926af..e99506b293 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -10,7 +10,7 @@ no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, - strposition_to_local_sql, + strposition_to_locate_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -442,7 +442,7 @@ class Generator(generator.Generator): exp.Trim: _trim_sql, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), - exp.StrPosition: strposition_to_local_sql, + exp.StrPosition: strposition_to_locate_sql, } ROOT_PROPERTIES = { diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 24d3bdf59d..3d261838f4 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -218,7 +218,7 @@ class Generator(generator.Generator): exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), - exp.StrPosition: rename_func("POSITION"), + exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a1a14c0b09..80cb69c20a 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -264,7 +264,11 @@ class Tokenizer(tokens.Tokenizer): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "CHARINDEX": exp.StrPosition.from_arg_list, + "CHARINDEX": lambda args: exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ), "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 0f2e9f3af5..3f82e0e7f5 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3023,7 +3023,7 @@ class Substring(Func): class StrPosition(Func): - arg_types = {"substr": True, "this": True, "position": False} + arg_types = {"this": True, "substr": True, "position": False} class StrToDate(Func): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9c99153eaf..f415a08569 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2625,7 +2625,11 @@ def _parse_position(self) -> exp.Expression: args.append(self._parse_bitwise()) # Note: we're parsing in order needle, haystack, position - this = exp.StrPosition.from_arg_list(args) + this = exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ) self.validate_expression(this, args) return this diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 337d6a0c7f..c9f85852d2 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -955,7 +955,7 @@ def test_operators(self): }, ) self.validate_all( - "STR_POSITION('a', x)", + "STR_POSITION(x, 'a')", write={ "drill": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')",