From cd102bb810ec46fb41e9d72a4f3358b8f8df1458 Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Fri, 22 Mar 2024 22:30:56 +0200 Subject: [PATCH] Fix `parent_model` parameter in GCP Vertex AI AutoML and Custom Job operators --- .../cloud/operators/vertex_ai/auto_ml.py | 6 +-- .../cloud/operators/vertex_ai/custom_job.py | 5 -- .../google/cloud/operators/test_vertex_ai.py | 46 +++++++++++++------ 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index 5269c4db255375..828380d147c5af 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -176,7 +176,6 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None model, training_id = self.hook.create_auto_ml_forecasting_training_job( project_id=self.project_id, region=self.region, @@ -284,7 +283,6 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None model, training_id = self.hook.create_auto_ml_image_training_job( project_id=self.project_id, region=self.region, @@ -393,7 +391,6 @@ def execute(self, context: Context): impersonation_chain=self.impersonation_chain, ) credentials, _ = self.hook.get_credentials_and_project_id() - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None model, training_id = self.hook.create_auto_ml_tabular_training_job( project_id=self.project_id, region=self.region, @@ -488,7 +485,7 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None + print(self.parent_model) model, training_id = self.hook.create_auto_ml_text_training_job( project_id=self.project_id, region=self.region, @@ -565,7 +562,6 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None model, training_id = self.hook.create_auto_ml_video_training_job( project_id=self.project_id, region=self.region, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index ba34f633acb31c..8802c3a26a05f7 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -468,8 +468,6 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None - model, training_id, custom_job_id = self.hook.create_custom_container_training_job( project_id=self.project_id, region=self.region, @@ -850,7 +848,6 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job( project_id=self.project_id, region=self.region, @@ -1234,8 +1231,6 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None - model, training_id, custom_job_id = self.hook.create_custom_training_job( project_id=self.project_id, region=self.region, diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 3f1452547144ff..92cd6457db6488 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -130,6 +130,7 @@ "metadata": "test-image-dataset", } TEST_DATASET_ID = "test-dataset-id" +TEST_PARENT_MODEL = "test-parent-model" TEST_EXPORT_CONFIG = { "annotationsFilter": "test-filter", "gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"}, @@ -190,8 +191,9 @@ class TestVertexAICreateCustomContainerTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) - def test_execute(self, mock_hook): + def test_execute(self, mock_hook, mock_dataset): mock_hook.return_value.create_custom_container_training_job.return_value = ( None, "training_id", @@ -217,8 +219,11 @@ def test_execute(self, mock_hook): test_fraction_split=TEST_FRACTION_SPLIT, region=GCP_LOCATION, project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) + mock_dataset.assert_called_once_with(name=TEST_DATASET_ID) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_custom_container_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, @@ -227,7 +232,7 @@ def test_execute(self, mock_hook): container_uri=CONTAINER_URI, model_serving_container_image_uri=CONTAINER_URI, command=COMMAND_2, - dataset=None, + dataset=mock_dataset.return_value, model_display_name=DISPLAY_NAME_2, replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, @@ -238,7 +243,7 @@ def test_execute(self, mock_hook): test_fraction_split=TEST_FRACTION_SPLIT, region=GCP_LOCATION, project_id=GCP_PROJECT, - parent_model=None, + parent_model=TEST_PARENT_MODEL, model_serving_container_predict_route=None, model_serving_container_health_route=None, model_serving_container_command=None, @@ -276,8 +281,9 @@ def test_execute(self, mock_hook): class TestVertexAICreateCustomPythonPackageTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) - def test_execute(self, mock_hook): + def test_execute(self, mock_hook, mock_dataset): mock_hook.return_value.create_custom_python_package_training_job.return_value = ( None, "training_id", @@ -304,8 +310,11 @@ def test_execute(self, mock_hook): test_fraction_split=TEST_FRACTION_SPLIT, region=GCP_LOCATION, project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) + mock_dataset.assert_called_once_with(name=TEST_DATASET_ID) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, @@ -315,7 +324,7 @@ def test_execute(self, mock_hook): model_serving_container_image_uri=CONTAINER_URI, python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, python_module_name=PYTHON_MODULE_NAME, - dataset=None, + dataset=mock_dataset.return_value, model_display_name=DISPLAY_NAME_2, replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, @@ -326,7 +335,7 @@ def test_execute(self, mock_hook): test_fraction_split=TEST_FRACTION_SPLIT, region=GCP_LOCATION, project_id=GCP_PROJECT, - parent_model=None, + parent_model=TEST_PARENT_MODEL, is_default_version=None, model_version_aliases=None, model_version_description=None, @@ -364,8 +373,9 @@ def test_execute(self, mock_hook): class TestVertexAICreateCustomTrainingJobOperator: + @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset")) @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) - def test_execute(self, mock_hook): + def test_execute(self, mock_hook, mock_dataset): mock_hook.return_value.create_custom_training_job.return_value = ( None, "training_id", @@ -385,9 +395,12 @@ def test_execute(self, mock_hook): replica_count=1, region=GCP_LOCATION, project_id=GCP_PROJECT, + dataset_id=TEST_DATASET_ID, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_dataset.assert_called_once_with(name=TEST_DATASET_ID) mock_hook.return_value.create_custom_training_job.assert_called_once_with( staging_bucket=STAGING_BUCKET, display_name=DISPLAY_NAME, @@ -396,7 +409,7 @@ def test_execute(self, mock_hook): model_serving_container_image_uri=CONTAINER_URI, script_path=PYTHON_PACKAGE, requirements=[], - dataset=None, + dataset=mock_dataset.return_value, model_display_name=None, replica_count=REPLICA_COUNT, machine_type=MACHINE_TYPE, @@ -405,7 +418,7 @@ def test_execute(self, mock_hook): training_fraction_split=None, validation_fraction_split=None, test_fraction_split=None, - parent_model=None, + parent_model=TEST_PARENT_MODEL, region=GCP_LOCATION, project_id=GCP_PROJECT, model_serving_container_predict_route=None, @@ -751,6 +764,7 @@ def test_execute(self, mock_hook, mock_dataset): sync=True, region=GCP_LOCATION, project_id=GCP_PROJECT, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -768,7 +782,7 @@ def test_execute(self, mock_hook, mock_dataset): forecast_horizon=TEST_TRAINING_FORECAST_HORIZON, data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT, data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT, - parent_model=None, + parent_model=TEST_PARENT_MODEL, optimization_objective=None, column_specs=None, column_transformations=None, @@ -814,6 +828,7 @@ def test_execute(self, mock_hook, mock_dataset): sync=True, region=GCP_LOCATION, project_id=GCP_PROJECT, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -824,7 +839,7 @@ def test_execute(self, mock_hook, mock_dataset): display_name=DISPLAY_NAME, dataset=mock_dataset.return_value, prediction_type="classification", - parent_model=None, + parent_model=TEST_PARENT_MODEL, multi_label=False, model_type="CLOUD", base_model=None, @@ -869,6 +884,7 @@ def test_execute(self, mock_hook, mock_dataset): sync=True, region=GCP_LOCATION, project_id=GCP_PROJECT, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -880,7 +896,7 @@ def test_execute(self, mock_hook, mock_dataset): region=GCP_LOCATION, display_name=DISPLAY_NAME, dataset=mock_dataset.return_value, - parent_model=None, + parent_model=TEST_PARENT_MODEL, target_column=None, optimization_prediction_type=None, optimization_objective=None, @@ -928,6 +944,7 @@ def test_execute(self, mock_hook, mock_dataset): sync=True, region=GCP_LOCATION, project_id=GCP_PROJECT, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -937,7 +954,7 @@ def test_execute(self, mock_hook, mock_dataset): region=GCP_LOCATION, display_name=DISPLAY_NAME, dataset=mock_dataset.return_value, - parent_model=None, + parent_model=TEST_PARENT_MODEL, prediction_type=None, multi_label=False, sentiment_max=10, @@ -975,6 +992,7 @@ def test_execute(self, mock_hook, mock_dataset): sync=True, region=GCP_LOCATION, project_id=GCP_PROJECT, + parent_model=TEST_PARENT_MODEL, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -984,7 +1002,7 @@ def test_execute(self, mock_hook, mock_dataset): region=GCP_LOCATION, display_name=DISPLAY_NAME, dataset=mock_dataset.return_value, - parent_model=None, + parent_model=TEST_PARENT_MODEL, prediction_type="classification", model_type="CLOUD", labels=None,