Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
americast committed Oct 24, 2023
1 parent 883c598 commit 4f6ea4a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 24 deletions.
32 changes: 16 additions & 16 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,29 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
assert (
len(required_columns) == 0
), f"Missing required {required_columns} columns for forecasting function."
outputs.extend(
[
ColumnDefinition(
arg_map.get("predict", "y") + "-lo",
ColumnType.FLOAT,
None,
None,
),
ColumnDefinition(
arg_map.get("predict", "y") + "-hi",
ColumnType.FLOAT,
None,
None,
),
]
)
else:
raise BinderError(
f"Unsupported type of function: {node.function_type}."
)
assert (
len(node.inputs) == 0 and len(node.outputs) == 0
), f"{node.function_type} functions' input and output are auto assigned"
outputs.extend(
[
ColumnDefinition(
arg_map.get("predict", "y") + "-lo",
ColumnType.INTEGER,
None,
None,
),
ColumnDefinition(
arg_map.get("predict", "y") + "-hi",
ColumnType.INTEGER,
None,
None,
),
]
)
node.inputs, node.outputs = inputs, outputs

@bind.register(SelectStatement)
Expand Down
2 changes: 1 addition & 1 deletion evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def get_optuna_config(trial):
model_save_dir_name = (
library + "_" + arg_map["model"] + "_" + new_freq
if "statsforecast" in library
else library + "_" + conf + "_" + arg_map["model"] + "_" + new_freq
else library + "_" + str(conf) + "_" + arg_map["model"] + "_" + new_freq
)
if len(data.columns) >= 4 and library == "neuralforecast":
model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))
Expand Down
5 changes: 3 additions & 2 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ def forward(self, data) -> pd.DataFrame:

for suggestion in set(suggestion_list):
print("\nSUGGESTION: " + self.suggestion_dict[suggestion])

forecast_df = forecast_df.rename(
columns={
"unique_id": self.id_column_rename,
"ds": self.time_column_rename,
self.model_name: self.predict_column_rename,
self.model_name
if self.library == "statsforecast"
else self.model_name + "-median": self.predict_column_rename,
self.model_name
+ "-lo-"
+ str(self.conf): self.predict_column_rename
Expand Down
25 changes: 22 additions & 3 deletions test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ def test_forecast(self):
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
self.assertEqual(
result.columns, ["airforecast.unique_id", "airforecast.ds", "airforecast.y"]
result.columns,
[
"airforecast.unique_id",
"airforecast.ds",
"airforecast.y",
"airforecast.y-lo",
"airforecast.y-hi",
],
)

create_predict_udf = """
Expand All @@ -116,7 +123,13 @@ def test_forecast(self):
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
["airpanelforecast.unique_id", "airpanelforecast.ds", "airpanelforecast.y"],
[
"airpanelforecast.unique_id",
"airpanelforecast.ds",
"airpanelforecast.y",
"airpanelforecast.y-lo",
"airpanelforecast.y-hi",
],
)

@forecast_skip_marker
Expand All @@ -143,7 +156,13 @@ def test_forecast_with_column_rename(self):
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
["homeforecast.type", "homeforecast.saledate", "homeforecast.ma"],
[
"homeforecast.type",
"homeforecast.saledate",
"homeforecast.ma",
"homeforecast.ma-lo",
"homeforecast.ma-hi",
],
)


Expand Down
37 changes: 35 additions & 2 deletions test/unit_tests/binder/test_statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,16 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self):
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
y_lo_col_obj = ColumnCatalogEntry(
name="y-lo",
type=ColumnType.FLOAT,
array_type=None,
)
y_hi_col_obj = ColumnCatalogEntry(
name="y-hi",
type=ColumnType.FLOAT,
array_type=None,
)
create_function_statement.query.target_list = [
TupleValueExpression(
name=id_col_obj.name, table_alias="a", col_object=id_col_obj
Expand Down Expand Up @@ -506,9 +516,16 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self):
col_obj.array_type,
col_obj.array_dimensions,
)
for col_obj in (id_col_obj, ds_col_obj, y_col_obj)
for col_obj in (
id_col_obj,
ds_col_obj,
y_col_obj,
y_lo_col_obj,
y_hi_col_obj,
)
]
)
print(create_function_statement.outputs)
self.assertEqual(create_function_statement.inputs, expected_inputs)
self.assertEqual(create_function_statement.outputs, expected_outputs)

Expand All @@ -534,6 +551,16 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self):
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
y_lo_col_obj = ColumnCatalogEntry(
name="ma-lo",
type=ColumnType.FLOAT,
array_type=None,
)
y_hi_col_obj = ColumnCatalogEntry(
name="ma-hi",
type=ColumnType.FLOAT,
array_type=None,
)
create_function_statement.query.target_list = [
TupleValueExpression(
name=id_col_obj.name, table_alias="a", col_object=id_col_obj
Expand Down Expand Up @@ -569,7 +596,13 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self):
col_obj.array_type,
col_obj.array_dimensions,
)
for col_obj in (id_col_obj, ds_col_obj, y_col_obj)
for col_obj in (
id_col_obj,
ds_col_obj,
y_col_obj,
y_lo_col_obj,
y_hi_col_obj,
)
]
)
self.assertEqual(create_function_statement.inputs, expected_inputs)
Expand Down

0 comments on commit 4f6ea4a

Please sign in to comment.