From f587edf2ee74169f1da6a408a04d7ed3c372d270 Mon Sep 17 00:00:00 2001 From: WS Hoekstra Date: Thu, 17 Oct 2024 09:43:55 +0200 Subject: [PATCH] vertex ai training operators: add display_name to rendered fields (#43028) * vertex ai training operators: add display_name to rendered fields * fix validate-operators-init static checks --------- Co-authored-by: Walter Hoekstra Co-authored-by: Shahar Epstein <60007259+shahar1@users.noreply.github.com> --- .../cloud/operators/vertex_ai/auto_ml.py | 11 ++++++++++- .../vertex_ai/batch_prediction_job.py | 2 +- .../cloud/operators/vertex_ai/custom_job.py | 19 +++++++++++++++++++ .../vertex_ai/hyperparameter_tuning_job.py | 1 + 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index ec9a6784f79ff..86b3ae017c080 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -106,6 +106,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator): "dataset_id", "region", "impersonation_chain", + "display_name", + "model_display_name", ) operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) @@ -121,6 +123,8 @@ def __init__( forecast_horizon: int, data_granularity_unit: str, data_granularity_count: int, + display_name: str, + model_display_name: str | None = None, optimization_objective: str | None = None, column_specs: dict[str, str] | None = None, column_transformations: list[dict[str, dict[str, str]]] | None = None, @@ -143,7 +147,12 @@ def __init__( **kwargs, ) -> None: super().__init__( - region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs + display_name=display_name, + model_display_name=model_display_name, + region=region, + impersonation_chain=impersonation_chain, + parent_model=parent_model, + **kwargs, ) self.dataset_id = dataset_id self.target_column = target_column diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py index a57a072fb68de..60ddd747eec56 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py @@ -163,7 +163,7 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator): :param poll_interval: Interval size which defines how often job status is checked in deferrable mode. """ - template_fields = ("region", "project_id", "model_name", "impersonation_chain") + template_fields = ("region", "project_id", "model_name", "impersonation_chain", "job_display_name") operator_extra_links = (VertexAIBatchPredictionJobLink(),) def __init__( diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 854101a72e34d..33ab07d075396 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -496,6 +496,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): "parent_model", "dataset_id", "impersonation_chain", + "display_name", + "model_display_name", ) operator_extra_links = ( VertexAIModelLink(), @@ -507,6 +509,8 @@ def __init__( *, command: Sequence[str] = [], region: str, + display_name: str, + model_display_name: str | None = None, parent_model: str | None = None, impersonation_chain: str | Sequence[str] | None = None, dataset_id: str | None = None, @@ -515,6 +519,8 @@ def __init__( **kwargs, ) -> None: super().__init__( + display_name=display_name, + model_display_name=model_display_name, region=region, parent_model=parent_model, impersonation_chain=impersonation_chain, @@ -949,6 +955,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator "region", "dataset_id", "impersonation_chain", + "display_name", + "model_display_name", ) operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) @@ -958,6 +966,8 @@ def __init__( python_package_gcs_uri: str, python_module_name: str, region: str, + display_name: str, + model_display_name: str | None = None, parent_model: str | None = None, impersonation_chain: str | Sequence[str] | None = None, dataset_id: str | None = None, @@ -966,6 +976,8 @@ def __init__( **kwargs, ) -> None: super().__init__( + display_name=display_name, + model_display_name=model_display_name, region=region, parent_model=parent_model, impersonation_chain=impersonation_chain, @@ -1405,6 +1417,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): "requirements", "dataset_id", "impersonation_chain", + "display_name", + "model_display_name", ) operator_extra_links = ( VertexAIModelLink(), @@ -1417,6 +1431,8 @@ def __init__( script_path: str, requirements: Sequence[str] | None = None, region: str, + display_name: str, + model_display_name: str | None = None, parent_model: str | None = None, impersonation_chain: str | Sequence[str] | None = None, dataset_id: str | None = None, @@ -1425,6 +1441,8 @@ def __init__( **kwargs, ) -> None: super().__init__( + display_name=display_name, + model_display_name=model_display_name, region=region, parent_model=parent_model, impersonation_chain=impersonation_chain, @@ -1732,6 +1750,7 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator): "region", "project_id", "impersonation_chain", + "display_name", ] operator_extra_links = [ VertexAITrainingPipelinesLink(), diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py index 2da2bf33946b8..43229e033cc8f 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py @@ -147,6 +147,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator): "region", "project_id", "impersonation_chain", + "display_name", ] operator_extra_links = (VertexAITrainingLink(),)