Skip to content

Commit

Permalink
feat(tsql): SPLIT_PART function and conversion to PARSENAME in tsql (#…
Browse files Browse the repository at this point in the history
…4211)

* feat(tsql): SPLIT_PART function and conversion to PARSENAME in tsql

* Fix: restore RIGHT in FUNCTIONS

* feat(tsql): SPLIT_PART function and conversion to PARSENAME in tsql (Enhanced)
  • Loading branch information
daihuynh authored Oct 8, 2024
1 parent dcdec95 commit 163e943
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 1 deletion.
46 changes: 46 additions & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,28 @@ def _parse(args: t.List[exp.Expression]) -> exp.Expression:
return _parse


# https://learn.microsoft.com/en-us/sql/t-sql/functions/parsename-transact-sql?view=sql-server-ver16
def _build_parsename(args: t.List) -> exp.SplitPart | exp.Anonymous:
    """Convert T-SQL PARSENAME(text, part_num) into a SplitPart expression.

    PARSENAME numbers the dot-separated parts from the RIGHT while
    SPLIT_PART numbers them from the LEFT, so the part index is reversed
    using the literal argument's part count. Falls back to an Anonymous
    PARSENAME node whenever the conversion cannot be computed at parse
    time (wrong arity, column arguments, more than 4 parts, or a
    non-integer part number).
    """
    anonymous = exp.Anonymous(this="PARSENAME", expressions=args)

    # Wrong number of arguments, or column-typed arguments whose part
    # count cannot be known at parse time
    if len(args) != 2 or isinstance(args[0], exp.Column) or isinstance(args[1], exp.Column):
        return anonymous

    arg_this: exp.Expression = args[0]
    # PARSENAME handles at most 4 object parts — see the MS docs above
    part_count = len(arg_this.name.split("."))
    if part_count > 4:
        return anonymous

    arg_partnum: exp.Expression = args[1]
    # Fix: a non-integer part-number literal (e.g. PARSENAME(x, 'a'))
    # previously raised ValueError mid-parse; fall back to Anonymous.
    try:
        part_num = int(arg_partnum.name)
    except ValueError:
        return anonymous

    # NULL input yields NULL for any index, so normalize to part 1
    length = 1 if isinstance(arg_this, exp.Null) else part_count + 1  # Reverse index
    idx = 0 if isinstance(arg_this, exp.Null) else part_num
    return exp.SplitPart(
        this=arg_this, delimiter=exp.Literal.string("."), part_num=exp.Literal.number(length - idx)
    )


def _build_json_query(args: t.List, dialect: Dialect) -> exp.JSONExtract:
if len(args) == 1:
# The default value for path is '$'. As a result, if you don't provide a
Expand Down Expand Up @@ -543,6 +565,7 @@ class Parser(parser.Parser):
"LEN": _build_with_arg_as_text(exp.Length),
"LEFT": _build_with_arg_as_text(exp.Left),
"RIGHT": _build_with_arg_as_text(exp.Right),
"PARSENAME": _build_parsename,
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
Expand Down Expand Up @@ -954,6 +977,29 @@ def lateral_op(self, expression: exp.Lateral) -> str:
self.unsupported("LATERAL clause is not supported.")
return "LATERAL"

def splitpart_sql(self: TSQL.Generator, expression: exp.SplitPart) -> str:
    """Render SPLIT_PART as T-SQL PARSENAME.

    PARSENAME indexes dot-separated parts from the RIGHT, whereas
    SPLIT_PART indexes from the LEFT, so the part number is reversed
    using the literal input's part count. When the conversion cannot be
    computed, emits an unsupported warning and renders only the string
    argument.
    """
    delimiter: exp.Expression = expression.args["delimiter"]
    # PARSENAME only ever splits on '.'
    if delimiter.name != ".":
        self.unsupported("PARSENAME only supports '.' delimiter")
        return self.sql(expression.this)

    arg_this: exp.Expression = expression.args["this"]
    arg_partnum: exp.Expression = expression.args["part_num"]
    # The reversed index can only be computed for literal arguments
    if isinstance(arg_this, exp.Column) or isinstance(arg_partnum, exp.Column):
        self.unsupported(
            "PARSENAME cannot calculate object_piece based on column-type arguments"
        )
        return self.sql(expression.this)

    # Fix: a non-integer part-number literal previously raised ValueError
    # during generation; treat it as unsupported like the cases above.
    try:
        part_num = int(arg_partnum.name)
    except ValueError:
        self.unsupported("PARSENAME requires an integer literal part number")
        return self.sql(expression.this)

    part_count = len(arg_this.name.split("."))

    # NULL input yields NULL for any index, so normalize to PARSENAME(NULL, 1)
    length = 1 if isinstance(arg_this, exp.Null) else part_count + 1  # Reverse index
    idx = 0 if isinstance(arg_this, exp.Null) else part_num
    return self.func("PARSENAME", arg_this, exp.Literal.number(length - idx))

def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6248,6 +6248,11 @@ class Split(Func):
arg_types = {"this": True, "expression": True, "limit": False}


# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html
class SplitPart(Func):
    """SPLIT_PART(this, delimiter, part_num): the part_num-th piece (1-based,
    per the Spark docs linked above) of *this* split on *delimiter*."""

    arg_types = {"this": True, "delimiter": True, "part_num": True}


# Start may be omitted in the case of postgres
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
Expand Down
70 changes: 69 additions & 1 deletion tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlglot import exp, parse, parse_one
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
from sqlglot.errors import ParseError, UnsupportedError
from sqlglot.optimizer.annotate_types import annotate_types


Expand Down Expand Up @@ -2014,3 +2014,71 @@ def test_grant(self):
self.validate_identity(
"GRANT EXECUTE ON TestProc TO User2 AS TesterRole", check_command_warning=True
)

def test_parsename(self):
    """Round-trip T-SQL PARSENAME against SPLIT_PART in other dialects."""
    # Test default case
    self.validate_all(
        "SELECT PARSENAME('1.2.3', 1)",
        read={
            "spark": "SELECT SPLIT_PART('1.2.3', '.', 3)",
            "databricks": "SELECT SPLIT_PART('1.2.3', '.', 3)",
        },
        write={
            "spark": "SELECT SPLIT_PART('1.2.3', '.', 3)",
            "databricks": "SELECT SPLIT_PART('1.2.3', '.', 3)",
            "tsql": "SELECT PARSENAME('1.2.3', 1)",
        },
    )
    # Test zero index
    self.validate_all(
        "SELECT PARSENAME('1.2.3', 0)",
        read={
            "spark": "SELECT SPLIT_PART('1.2.3', '.', 4)",
            "databricks": "SELECT SPLIT_PART('1.2.3', '.', 4)",
        },
        write={
            "spark": "SELECT SPLIT_PART('1.2.3', '.', 4)",
            "databricks": "SELECT SPLIT_PART('1.2.3', '.', 4)",
            "tsql": "SELECT PARSENAME('1.2.3', 0)",
        },
    )
    # Test null value
    self.validate_all(
        "SELECT PARSENAME(NULL, 1)",
        read={
            "spark": "SELECT SPLIT_PART(NULL, '.', 1)",
            "databricks": "SELECT SPLIT_PART(NULL, '.', 1)",
        },
        write={
            "spark": "SELECT SPLIT_PART(NULL, '.', 1)",
            "databricks": "SELECT SPLIT_PART(NULL, '.', 1)",
            "tsql": "SELECT PARSENAME(NULL, 1)",
        },
    )
    # Test non-dot delimiter: cannot be expressed with PARSENAME
    self.validate_all(
        "SELECT SPLIT_PART('1,2,3', ',', 1)",
        write={
            "spark": "SELECT SPLIT_PART('1,2,3', ',', 1)",
            "databricks": "SELECT SPLIT_PART('1,2,3', ',', 1)",
            "tsql": UnsupportedError,
        },
    )
    # Test input with more than 4 dot-separated parts
    self.validate_all(
        "SELECT SPLIT_PART('1.2.3.4.5', '.', 1)",
        write={
            "spark": "SELECT SPLIT_PART('1.2.3.4.5', '.', 1)",
            "databricks": "SELECT SPLIT_PART('1.2.3.4.5', '.', 1)",
            "tsql": "SELECT PARSENAME('1.2.3.4.5', 5)",
        },
    )
    # Test column-type parameters
    self.validate_all(
        "WITH t AS (SELECT 'a.b.c' AS value, 1 AS idx) SELECT SPLIT_PART(value, '.', idx) FROM t",
        write={
            "spark": "WITH t AS (SELECT 'a.b.c' AS value, 1 AS idx) SELECT SPLIT_PART(value, '.', idx) FROM t",
            "databricks": "WITH t AS (SELECT 'a.b.c' AS value, 1 AS idx) SELECT SPLIT_PART(value, '.', idx) FROM t",
            "tsql": UnsupportedError,
        },
    )

0 comments on commit 163e943

Please sign in to comment.