Skip to content

Commit

Permalink
vertex ai training operators: add display_name to rendered fields (#4…
Browse files Browse the repository at this point in the history
…3028)

* vertex ai training operators: add display_name to rendered fields

* fix validate-operators-init static checks

---------

Co-authored-by: Walter Hoekstra <walterhoekstra@bol.com>
Co-authored-by: Shahar Epstein <60007259+shahar1@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 17, 2024
1 parent b3b4850 commit f587edf
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
"dataset_id",
"region",
"impersonation_chain",
"display_name",
"model_display_name",
)
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())

Expand All @@ -121,6 +123,8 @@ def __init__(
forecast_horizon: int,
data_granularity_unit: str,
data_granularity_count: int,
display_name: str,
model_display_name: str | None = None,
optimization_objective: str | None = None,
column_specs: dict[str, str] | None = None,
column_transformations: list[dict[str, dict[str, str]]] | None = None,
Expand All @@ -143,7 +147,12 @@ def __init__(
**kwargs,
) -> None:
super().__init__(
region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
display_name=display_name,
model_display_name=model_display_name,
region=region,
impersonation_chain=impersonation_chain,
parent_model=parent_model,
**kwargs,
)
self.dataset_id = dataset_id
self.target_column = target_column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
:param poll_interval: Interval size which defines how often job status is checked in deferrable mode.
"""

template_fields = ("region", "project_id", "model_name", "impersonation_chain")
template_fields = ("region", "project_id", "model_name", "impersonation_chain", "job_display_name")
operator_extra_links = (VertexAIBatchPredictionJobLink(),)

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
"parent_model",
"dataset_id",
"impersonation_chain",
"display_name",
"model_display_name",
)
operator_extra_links = (
VertexAIModelLink(),
Expand All @@ -507,6 +509,8 @@ def __init__(
*,
command: Sequence[str] = [],
region: str,
display_name: str,
model_display_name: str | None = None,
parent_model: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
dataset_id: str | None = None,
Expand All @@ -515,6 +519,8 @@ def __init__(
**kwargs,
) -> None:
super().__init__(
display_name=display_name,
model_display_name=model_display_name,
region=region,
parent_model=parent_model,
impersonation_chain=impersonation_chain,
Expand Down Expand Up @@ -949,6 +955,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
"region",
"dataset_id",
"impersonation_chain",
"display_name",
"model_display_name",
)
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())

Expand All @@ -958,6 +966,8 @@ def __init__(
python_package_gcs_uri: str,
python_module_name: str,
region: str,
display_name: str,
model_display_name: str | None = None,
parent_model: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
dataset_id: str | None = None,
Expand All @@ -966,6 +976,8 @@ def __init__(
**kwargs,
) -> None:
super().__init__(
display_name=display_name,
model_display_name=model_display_name,
region=region,
parent_model=parent_model,
impersonation_chain=impersonation_chain,
Expand Down Expand Up @@ -1405,6 +1417,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
"requirements",
"dataset_id",
"impersonation_chain",
"display_name",
"model_display_name",
)
operator_extra_links = (
VertexAIModelLink(),
Expand All @@ -1417,6 +1431,8 @@ def __init__(
script_path: str,
requirements: Sequence[str] | None = None,
region: str,
display_name: str,
model_display_name: str | None = None,
parent_model: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
dataset_id: str | None = None,
Expand All @@ -1425,6 +1441,8 @@ def __init__(
**kwargs,
) -> None:
super().__init__(
display_name=display_name,
model_display_name=model_display_name,
region=region,
parent_model=parent_model,
impersonation_chain=impersonation_chain,
Expand Down Expand Up @@ -1732,6 +1750,7 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
"region",
"project_id",
"impersonation_chain",
"display_name",
]
operator_extra_links = [
VertexAITrainingPipelinesLink(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
"region",
"project_id",
"impersonation_chain",
"display_name",
]
operator_extra_links = (VertexAITrainingLink(),)

Expand Down

0 comments on commit f587edf

Please sign in to comment.