Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
americast committed Oct 24, 2023
1 parent 883c598 commit b03c024
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 21 deletions.
32 changes: 16 additions & 16 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,29 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
assert (
len(required_columns) == 0
), f"Missing required {required_columns} columns for forecasting function."
outputs.extend(
[
ColumnDefinition(
arg_map.get("predict", "y") + "-lo",
ColumnType.INTEGER,
None,
None,
),
ColumnDefinition(
arg_map.get("predict", "y") + "-hi",
ColumnType.INTEGER,
None,
None,
),
]
)
else:
raise BinderError(
f"Unsupported type of function: {node.function_type}."
)
assert (
len(node.inputs) == 0 and len(node.outputs) == 0
), f"{node.function_type} functions' input and output are auto assigned"
outputs.extend(
[
ColumnDefinition(
arg_map.get("predict", "y") + "-lo",
ColumnType.INTEGER,
None,
None,
),
ColumnDefinition(
arg_map.get("predict", "y") + "-hi",
ColumnType.INTEGER,
None,
None,
),
]
)
node.inputs, node.outputs = inputs, outputs

@bind.register(SelectStatement)
Expand Down
25 changes: 22 additions & 3 deletions test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ def test_forecast(self):
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
self.assertEqual(
result.columns, ["airforecast.unique_id", "airforecast.ds", "airforecast.y"]
result.columns,
[
"airforecast.unique_id",
"airforecast.ds",
"airforecast.y",
"airforecast.y-lo",
"airforecast.y-hi",
],
)

create_predict_udf = """
Expand All @@ -116,7 +123,13 @@ def test_forecast(self):
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
["airpanelforecast.unique_id", "airpanelforecast.ds", "airpanelforecast.y"],
[
"airpanelforecast.unique_id",
"airpanelforecast.ds",
"airpanelforecast.y",
"airpanelforecast.y-lo",
"airpanelforecast.y-hi",
],
)

@forecast_skip_marker
Expand All @@ -143,7 +156,13 @@ def test_forecast_with_column_rename(self):
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
["homeforecast.type", "homeforecast.saledate", "homeforecast.ma"],
[
"homeforecast.type",
"homeforecast.saledate",
"homeforecast.ma",
"homeforecast.ma-lo",
"homeforecast.ma-hi",
],
)


Expand Down
56 changes: 54 additions & 2 deletions test/unit_tests/binder/test_statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,18 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self):
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
y_lo_col_obj = ColumnCatalogEntry(
name="y-lo",
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
y_hi_col_obj = ColumnCatalogEntry(
name="y-hi",
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
create_function_statement.query.target_list = [
TupleValueExpression(
name=id_col_obj.name, table_alias="a", col_object=id_col_obj
Expand Down Expand Up @@ -506,7 +518,21 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self):
col_obj.array_type,
col_obj.array_dimensions,
)
for col_obj in (id_col_obj, ds_col_obj, y_col_obj)
for col_obj in (
id_col_obj,
ds_col_obj,
y_col_obj,
TupleValueExpression(
name=y_lo_col_obj.name,
table_alias="a",
col_object=y_col_obj,
),
TupleValueExpression(
name=y_hi_col_obj.name,
table_alias="a",
col_object=y_col_obj,
),
)
]
)
self.assertEqual(create_function_statement.inputs, expected_inputs)
Expand Down Expand Up @@ -534,6 +560,18 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self):
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
y_lo_col_obj = ColumnCatalogEntry(
name="y-lo",
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
y_hi_col_obj = ColumnCatalogEntry(
name="y-hi",
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
create_function_statement.query.target_list = [
TupleValueExpression(
name=id_col_obj.name, table_alias="a", col_object=id_col_obj
Expand Down Expand Up @@ -569,7 +607,21 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self):
col_obj.array_type,
col_obj.array_dimensions,
)
for col_obj in (id_col_obj, ds_col_obj, y_col_obj)
for col_obj in (
id_col_obj,
ds_col_obj,
y_col_obj,
TupleValueExpression(
name=y_lo_col_obj.name,
table_alias="a",
col_object=y_col_obj,
),
TupleValueExpression(
name=y_hi_col_obj.name,
table_alias="a",
col_object=y_col_obj,
),
)
]
)
self.assertEqual(create_function_statement.inputs, expected_inputs)
Expand Down

0 comments on commit b03c024

Please sign in to comment.