Skip to content

Commit

Permalink
fix str_position order
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jan 17, 2023
1 parent 650e123 commit a1fb957
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def locate_to_strposition(args):
)


def strposition_to_local_sql(self, expression):
def strposition_to_locate_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
no_safe_divide_sql,
no_trycast_sql,
rename_func,
strposition_to_local_sql,
strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
var_map_sql,
Expand Down Expand Up @@ -297,7 +297,7 @@ class Generator(generator.Generator):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_local_sql,
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
strposition_to_local_sql,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -442,7 +442,7 @@ class Generator(generator.Generator):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_local_sql,
exp.StrPosition: strposition_to_locate_sql,
}

ROOT_PROPERTIES = {
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class Generator(generator.Generator):
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
Expand Down
6 changes: 5 additions & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,11 @@ class Tokenizer(tokens.Tokenizer):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3023,7 +3023,7 @@ class Substring(Func):


class StrPosition(Func):
arg_types = {"substr": True, "this": True, "position": False}
arg_types = {"this": True, "substr": True, "position": False}


class StrToDate(Func):
Expand Down
6 changes: 5 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2625,7 +2625,11 @@ def _parse_position(self) -> exp.Expression:
args.append(self._parse_bitwise())

# Note: we're parsing in order needle, haystack, position
this = exp.StrPosition.from_arg_list(args)
this = exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
self.validate_expression(this, args)

return this
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def test_operators(self):
},
)
self.validate_all(
"STR_POSITION('a', x)",
"STR_POSITION(x, 'a')",
write={
"drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')",
Expand Down

0 comments on commit a1fb957

Please sign in to comment.