Skip to content

Commit

Permalink
Feat(bigquery): improve support for CREATE MODEL DDL statement (#2380)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Oct 5, 2023
1 parent dd8334d commit 3266e51
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 23 deletions.
7 changes: 4 additions & 3 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,16 @@ class Tokenizer(tokens.Tokenizer):
"ANY TYPE": TokenType.VARIANT,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"DECLARE": TokenType.COMMAND,
"FLOAT64": TokenType.DOUBLE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"INT64": TokenType.BIGINT,
"MODEL": TokenType.MODEL,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"RECORD": TokenType.STRUCT,
"TIMESTAMP": TokenType.TIMESTAMPTZ,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
KEYWORDS.pop("DIV")

Expand Down
20 changes: 18 additions & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2040,8 +2040,12 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}


class InputOutputFormat(Expression):
arg_types = {"input_format": False, "output_format": False}
class InputModelProperty(Property):
arg_types = {"this": True}


class OutputModelProperty(Property):
arg_types = {"this": True}


class IsolatedLoadingProperty(Property):
Expand Down Expand Up @@ -2137,6 +2141,10 @@ class PartitionedByProperty(Property):
arg_types = {"this": True}


class RemoteWithConnectionModelProperty(Property):
arg_types = {"this": True}


class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}

Expand Down Expand Up @@ -2211,6 +2219,10 @@ class TemporaryProperty(Property):
arg_types = {}


class TransformModelProperty(Property):
arg_types = {"expressions": True}


class TransientProperty(Property):
arg_types = {"this": False}

Expand Down Expand Up @@ -2293,6 +2305,10 @@ class Qualify(Expression):
pass


class InputOutputFormat(Expression):
arg_types = {"input_format": False, "output_format": False}


# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
class Return(Expression):
pass
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Generator:
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
Expand All @@ -84,7 +85,9 @@ class Generator:
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
Expand All @@ -94,6 +97,7 @@ class Generator:
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
Expand Down Expand Up @@ -278,6 +282,7 @@ class Generator:
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
Expand All @@ -291,9 +296,11 @@ class Generator:
exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
Expand All @@ -310,6 +317,7 @@ class Generator:
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
Expand Down
13 changes: 13 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
TokenType.MODEL,
TokenType.DICTIONARY,
}

Expand Down Expand Up @@ -649,6 +650,7 @@ class Parser(metaclass=_Parser):
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"),
Expand All @@ -664,11 +666,13 @@ class Parser(metaclass=_Parser):
"NO": lambda self: self._parse_no_property(),
"ON": lambda self: self._parse_on_property(),
"ORDER BY": lambda self: self._parse_order(skip_order_token=True),
"OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
"RANGE": lambda self: self._parse_dict_range(this="RANGE"),
"REMOTE": lambda self: self._parse_remote_with_connection(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
Expand All @@ -690,6 +694,9 @@ class Parser(metaclass=_Parser):
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
"TO": lambda self: self._parse_to_table(),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"TRANSFORM": lambda self: self.expression(
exp.TransformModelProperty, expressions=self._parse_wrapped_csv(self._parse_expression)
),
"TTL": lambda self: self._parse_ttl(),
"USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"VOLATILE": lambda self: self._parse_volatile_property(),
Expand Down Expand Up @@ -1788,6 +1795,12 @@ def _parse_character_set(self, default: bool = False) -> exp.CharacterSetPropert
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)

def _parse_remote_with_connection(self) -> exp.RemoteWithConnectionModelProperty:
self._match_text_seq("WITH", "CONNECTION")
return self.expression(
exp.RemoteWithConnectionModelProperty, this=self._parse_table_parts()
)

def _parse_returns(self) -> exp.ReturnsProperty:
value: t.Optional[exp.Expression]
is_table = self._match(TokenType.TABLE)
Expand Down
1 change: 1 addition & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ class TokenType(AutoName):
MEMBER_OF = auto()
MERGE = auto()
MOD = auto()
MODEL = auto()
NATURAL = auto()
NEXT = auto()
NOTNULL = auto()
Expand Down
75 changes: 57 additions & 18 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,6 @@ def test_bigquery(self):
self.validate_identity("ROLLBACK TRANSACTION")
self.validate_identity("CAST(x AS BIGNUMERIC)")
self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1")
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
)
self.validate_identity(
"SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))"
)
self.validate_identity(
"DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')"
)
Expand Down Expand Up @@ -822,6 +804,63 @@ def test_remove_precision_parameterized_types(self):
},
)

def test_models(self):
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
)
self.validate_identity(
"SELECT label, predicted_label1, predicted_label AS predicted_label2 FROM ML.PREDICT(MODEL mydataset.mymodel2, (SELECT * EXCEPT (predicted_label), predicted_label AS predicted_label1 FROM ML.PREDICT(MODEL mydataset.mymodel1, TABLE mydataset.mytable)))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 480, 480, FALSE) AS input FROM my_dataset.object_table))"
)
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL my_dataset.vision_model, (SELECT uri, ML.CONVERT_COLOR_SPACE(ML.RESIZE_IMAGE(ML.DECODE_IMAGE(data), 224, 280, TRUE), 'YIQ') AS input FROM my_dataset.object_table WHERE content_type = 'image/jpeg'))"
)
self.validate_identity(
"CREATE OR REPLACE MODEL foo OPTIONS (model_type='linear_reg') AS SELECT bla FROM foo WHERE cond"
)
self.validate_identity(
"""CREATE OR REPLACE MODEL m
TRANSFORM(
ML.FEATURE_CROSS(STRUCT(f1, f2)) AS cross_f,
ML.QUANTILE_BUCKETIZE(f3) OVER () AS buckets,
label_col
)
OPTIONS (
model_type='linear_reg',
input_label_cols=['label_col']
) AS
SELECT
*
FROM t""",
pretty=True,
)
self.validate_identity(
"""CREATE MODEL project_id.mydataset.mymodel
INPUT(
f1 INT64,
f2 FLOAT64,
f3 STRING,
f4 ARRAY<INT64>
)
OUTPUT(
out1 INT64,
out2 INT64
)
REMOTE WITH CONNECTION myproject.us.test_connection
OPTIONS (
ENDPOINT='https://us-central1-aiplatform.googleapis.com/v1/projects/myproject/locations/us-central1/endpoints/1234'
)""",
pretty=True,
)

def test_merge(self):
self.validate_all(
"""
Expand Down

0 comments on commit 3266e51

Please sign in to comment.