Skip to content

Commit

Permalink
fix: Fix default AutoML Forecasting transformations list.
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 526734524
  • Loading branch information
TheMichaelHu authored and copybara-github committed Apr 24, 2023
1 parent 06f8508 commit 77b89c0
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 16 deletions.
4 changes: 3 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2438,7 +2438,9 @@ def _run(
(
self._column_transformations,
column_names,
) = dataset._get_default_column_transformations(target_column)
) = column_transformations_utils.get_default_column_transformations(
dataset=dataset, target_column=target_column
)

_LOGGER.info(
"The column transformation of type 'auto' was set for the following columns: %s."
Expand Down
21 changes: 6 additions & 15 deletions google/cloud/aiplatform/utils/column_transformations_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#

from typing import Dict, List, Optional, Tuple
import warnings

from google.cloud.aiplatform import datasets

Expand Down Expand Up @@ -51,9 +50,9 @@ def get_default_column_transformations(


def validate_and_get_column_transformations(
column_specs: Optional[Dict[str, str]],
column_transformations: Optional[List[Dict[str, Dict[str, str]]]],
) -> List[Dict[str, Dict[str, str]]]:
column_specs: Optional[Dict[str, str]] = None,
column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
) -> Optional[List[Dict[str, Dict[str, str]]]]:
"""Validates column specs and transformations, then returns processed transformations.
Args:
Expand Down Expand Up @@ -91,21 +90,13 @@ def validate_and_get_column_transformations(
# user populated transformations
if column_transformations is not None and column_specs is not None:
raise ValueError(
"Both column_transformations and column_specs were passed. Only one is allowed."
"Both column_transformations and column_specs were passed. Only "
"one is allowed."
)
if column_transformations is not None:
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"consider using column_specs instead. column_transformations will be deprecated in the future.",
DeprecationWarning,
stacklevel=2,
)

return column_transformations
elif column_specs is not None:
return [
{transformation: {"column_name": column_name}}
for column_name, transformation in column_specs.items()
]
else:
return None
return column_transformations
38 changes: 38 additions & 0 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,3 +1294,41 @@ def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference
training_pipeline=true_training_pipeline,
timeout=None,
)

def test_automl_forecasting_with_no_transformations(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_dataset_time_series,
    mock_model_service_get,
):
    """Runs a forecasting job without user-supplied column transformations.

    The mocked dataset reports the columns "a", "b" and the target column.
    After ``run`` the job is expected to hold one "auto" transformation for
    "a" and one for "b", and none for the target column.
    """
    feature_columns = ["a", "b"]

    aiplatform.init(project=_TEST_PROJECT)
    job = training_jobs.AutoMLForecastingTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
    )

    # The dataset exposes the feature columns followed by the target column.
    mock_dataset_time_series.column_names = [
        *feature_columns,
        _TEST_TRAINING_TARGET_COLUMN,
    ]

    # No column_transformations / column_specs argument is passed on purpose.
    run_kwargs = dict(
        dataset=mock_dataset_time_series,
        predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
        target_column=_TEST_TRAINING_TARGET_COLUMN,
        time_column=_TEST_TRAINING_TIME_COLUMN,
        time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
        unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
        available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
        forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
        data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
        data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
        model_display_name=_TEST_MODEL_DISPLAY_NAME,
        time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
        context_window=_TEST_TRAINING_CONTEXT_WINDOW,
        budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
    )
    job.run(**run_kwargs)

    expected_transformations = [
        {"auto": {"column_name": name}} for name in feature_columns
    ]
    assert job._column_transformations == expected_transformations
47 changes: 47 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from google.cloud import storage
from google.cloud.aiplatform import compat, utils
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.utils import (
column_transformations_utils,
gcs_utils,
pipeline_utils,
prediction_utils,
Expand Down Expand Up @@ -485,6 +487,51 @@ def test_timestamped_unique_name():
assert re.match(r"\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-.{5}", name)


class TestColumnTransformationsUtils:
    """Unit tests for the helpers in ``column_transformations_utils``."""

    # Transformations in list-of-dicts form: {transformation: {"column_name": name}}.
    column_transformations = [
        {"auto": {"column_name": "a"}},
        {"auto": {"column_name": "b"}},
    ]
    # The same configuration in column-name -> transformation form.
    column_specs = {"a": "auto", "b": "auto"}

    def test_get_default_column_transformations(self):
        """Defaults cover every dataset column except the target column."""
        dataset_mock = mock.MagicMock(datasets.TimeSeriesDataset)
        dataset_mock.column_names = ["a", "b", "target"]

        result = column_transformations_utils.get_default_column_transformations(
            dataset=dataset_mock, target_column="target"
        )
        transforms, columns = result

        expected_transforms = [
            {"auto": {"column_name": "a"}},
            {"auto": {"column_name": "b"}},
        ]
        assert transforms == expected_transforms
        assert columns == ["a", "b"]

    def test_validate_transformations_with_multiple_configs(self):
        """Passing both configuration styles at once raises ValueError."""
        validate = column_transformations_utils.validate_and_get_column_transformations
        with pytest.raises(ValueError):
            validate(
                column_transformations=self.column_transformations,
                column_specs=self.column_specs,
            )

    def test_validate_transformations_with_column_specs(self):
        """Column specs are converted into the list-of-dicts form."""
        validate = column_transformations_utils.validate_and_get_column_transformations
        result = validate(column_specs=self.column_specs)
        assert result == self.column_transformations

    def test_validate_transformations_with_column_transformations(self):
        """Explicit column transformations are returned as given."""
        validate = column_transformations_utils.validate_and_get_column_transformations
        result = validate(column_transformations=self.column_transformations)
        assert result == self.column_transformations


@pytest.mark.usefixtures("google_auth_mock")
class TestGcsUtils:
def test_upload_to_gcs(self, json_file, mock_storage_blob_upload_from_filename):
Expand Down

0 comments on commit 77b89c0

Please sign in to comment.