Skip to content

Commit

Permalink
Feat: add RETURNS NULL ON NULL and STRICT properties (#3504)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored May 18, 2024
1 parent e281db8 commit 9aee21b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 10 deletions.
6 changes: 5 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2700,7 +2700,11 @@ class RemoteWithConnectionModelProperty(Property):


class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
arg_types = {"this": False, "is_table": False, "table": False, "null": False}


class StrictProperty(Property):
arg_types = {}


class RowFormatProperty(Property):
Expand Down
6 changes: 5 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ class Generator(metaclass=_Generator):
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: (
"RETURNS NULL ON NULL INPUT" if e.args.get("null") else self.naked_property(e)
),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
Expand All @@ -135,6 +137,7 @@ class Generator(metaclass=_Generator):
exp.SqlSecurityProperty: lambda _,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda _, e: e.name,
exp.StrictProperty: lambda *_: "STRICT",
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
Expand Down Expand Up @@ -476,6 +479,7 @@ class Generator(metaclass=_Generator):
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.StrictProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ class Parser(metaclass=_Parser):
"READS": lambda self: self._parse_reads_property(),
"REMOTE": lambda self: self._parse_remote_with_connection(),
"RETURNS": lambda self: self._parse_returns(),
"STRICT": lambda self: self.expression(exp.StrictProperty),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SAMPLE": lambda self: self.expression(
Expand Down Expand Up @@ -2309,6 +2310,7 @@ def _parse_remote_with_connection(self) -> exp.RemoteWithConnectionModelProperty

def _parse_returns(self) -> exp.ReturnsProperty:
value: t.Optional[exp.Expression]
null = None
is_table = self._match(TokenType.TABLE)

if is_table:
Expand All @@ -2322,10 +2324,13 @@ def _parse_returns(self) -> exp.ReturnsProperty:
self.raise_error("Expecting >")
else:
value = self._parse_schema(exp.var("TABLE"))
elif self._match_text_seq("NULL", "ON", "NULL", "INPUT"):
null = True
value = None
else:
value = self._parse_types()

return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table, null=null)

def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
Expand Down
12 changes: 5 additions & 7 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,23 +805,21 @@ def test_ddl(self):
"CREATE TABLE test (x TIMESTAMP[][])",
)
self.validate_identity(
"CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE",
check_command_warning=True,
"CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT AS 'select $1 + $2;'",
)
self.validate_identity(
"CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT",
check_command_warning=True,
"CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE STRICT AS 'select $1 + $2;'"
)
self.validate_identity(
"CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT",
"CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE",
check_command_warning=True,
)
self.validate_identity(
"CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT",
"CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT",
check_command_warning=True,
)
self.validate_identity(
"CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE STRICT",
"CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT",
check_command_warning=True,
)
self.validate_identity(
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,12 @@ def test_ddl(self):
self.validate_identity(
"CREATE ICEBERG TABLE my_iceberg_table (amount ARRAY(INT)) CATALOG='SNOWFLAKE' EXTERNAL_VOLUME='my_external_volume' BASE_LOCATION='my/relative/path/from/extvol'"
)
self.validate_identity(
"""CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT RETURNS NULL ON NULL INPUT AS ' return Object.values(obj) '"""
)
self.validate_identity(
"""CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT STRICT AS ' return Object.values(obj) '"""
)
self.validate_identity(
"CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$",
"CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '",
Expand Down

0 comments on commit 9aee21b

Please sign in to comment.