Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for failure_policy in PipelineJob #1452

Merged
merged 8 commits into from
Jun 23, 2022
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
types.model_evaluation_slice = types.model_evaluation_slice_v1beta1
types.model_service = types.model_service_v1beta1
types.operation = types.operation_v1beta1
types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
types.pipeline_job = types.pipeline_job_v1beta1
types.pipeline_service = types.pipeline_service_v1beta1
types.pipeline_state = types.pipeline_state_v1beta1
Expand Down Expand Up @@ -180,6 +181,7 @@
types.model_evaluation_slice = types.model_evaluation_slice_v1
types.model_service = types.model_service_v1
types.operation = types.operation_v1
types.pipeline_failure_policy = types.pipeline_failure_policy_v1
types.pipeline_job = types.pipeline_job_v1
types.pipeline_service = types.pipeline_service_v1
types.pipeline_state = types.pipeline_state_v1
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
model_evaluation_slice as model_evaluation_slice_v1beta1,
model_service as model_service_v1beta1,
operation as operation_v1beta1,
pipeline_failure_policy as pipeline_failure_policy_v1beta1,
pipeline_job as pipeline_job_v1beta1,
pipeline_service as pipeline_service_v1beta1,
pipeline_state as pipeline_state_v1beta1,
Expand Down Expand Up @@ -126,6 +127,7 @@
model_evaluation_slice as model_evaluation_slice_v1,
model_service as model_service_v1,
operation as operation_v1,
pipeline_failure_policy as pipeline_failure_policy_v1,
pipeline_job as pipeline_job_v1,
pipeline_service as pipeline_service_v1,
pipeline_state as pipeline_state_v1,
Expand Down Expand Up @@ -191,6 +193,7 @@
model_evaluation_slice_v1,
model_service_v1,
operation_v1,
pipeline_failure_policy_v1beta1,
pipeline_job_v1,
pipeline_service_v1,
pipeline_state_v1,
Expand Down Expand Up @@ -253,6 +256,7 @@
model_evaluation_slice_v1beta1,
model_service_v1beta1,
operation_v1beta1,
pipeline_failure_policy_v1beta1,
pipeline_job_v1beta1,
pipeline_service_v1beta1,
pipeline_state_v1beta1,
Expand Down
11 changes: 11 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
credentials: Optional[auth_credentials.Credentials] = None,
project: Optional[str] = None,
location: Optional[str] = None,
failure_policy: Optional[str] = None,
):
"""Retrieves a PipelineJob resource and instantiates its
representation.
Expand Down Expand Up @@ -173,6 +174,15 @@ def __init__(
location (str):
Optional. Location to create PipelineJob. If not set,
location set in aiplatform.init will be used.
failure_policy (str):
Optional. The failure policy - "slow" or "fast".
Currently, the default of a pipeline is that the pipeline will continue to
run until no more tasks can be executed, also known as
PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow").
However, if a pipeline is set to
PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"),
it will stop scheduling any new tasks when a task has failed. Any
scheduled tasks will continue to completion.
sararob marked this conversation as resolved.
Show resolved Hide resolved

Raises:
ValueError: If job_id or labels have incorrect format.
Expand Down Expand Up @@ -219,6 +229,7 @@ def __init__(
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
builder.update_failure_policy(failure_policy)
runtime_config_dict = builder.build()

runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
Expand Down
45 changes: 43 additions & 2 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import json
from typing import Any, Dict, Mapping, Optional, Union
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
import packaging.version


Expand All @@ -32,6 +33,7 @@ def __init__(
schema_version: str,
parameter_types: Mapping[str, str],
parameter_values: Optional[Dict[str, Any]] = None,
failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
):
"""Creates a PipelineRuntimeConfigBuilder object.

Expand All @@ -44,11 +46,20 @@ def __init__(
Required. The mapping from pipeline parameter name to its type.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter name to its value.
failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
Optional. Represents the failure policy of a pipeline. Currently, the
default of a pipeline is that the pipeline will continue to
run until no more tasks can be executed, also known as
PIPELINE_FAILURE_POLICY_FAIL_SLOW. However, if a pipeline is
set to PIPELINE_FAILURE_POLICY_FAIL_FAST, it will stop
scheduling any new tasks when a task has failed. Any
scheduled tasks will continue to completion.
"""
self._pipeline_root = pipeline_root
self._schema_version = schema_version
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})
self._failure_policy = failure_policy

@classmethod
def from_job_spec_json(
Expand Down Expand Up @@ -80,7 +91,14 @@ def from_job_spec_json(

pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
parameter_values = _parse_runtime_parameters(runtime_config_spec)
return cls(pipeline_root, schema_version, parameter_types, parameter_values)
failure_policy = runtime_config_spec.get("failurePolicy")
return cls(
pipeline_root,
schema_version,
parameter_types,
parameter_values,
failure_policy,
)

def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
"""Updates pipeline_root value.
Expand Down Expand Up @@ -111,6 +129,16 @@ def update_runtime_parameters(
parameters[k] = json.dumps(v)
self._parameter_values.update(parameters)

def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
"""Merges runtime failure policy.

Args:
failure_policy (str):
Optional. The failure policy - "slow" or "fast".
"""
if failure_policy:
self._failure_policy = _FAILURE_POLICY_TO_ENUM_VALUE[failure_policy]
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved

def build(self) -> Dict[str, Any]:
"""Build a RuntimeConfig proto.

Expand All @@ -128,7 +156,8 @@ def build(self) -> Dict[str, Any]:
parameter_values_key = "parameterValues"
else:
parameter_values_key = "parameters"
return {

runtime_config = {
"gcsOutputDirectory": self._pipeline_root,
parameter_values_key: {
k: self._get_vertex_value(k, v)
Expand All @@ -137,6 +166,11 @@ def build(self) -> Dict[str, Any]:
},
}

if self._failure_policy:
runtime_config["failurePolicy"] = self._failure_policy

return runtime_config

def _get_vertex_value(
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Union[int, float, str, bool, list, dict]:
Expand Down Expand Up @@ -205,3 +239,10 @@ def _parse_runtime_parameters(
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result


_FAILURE_POLICY_TO_ENUM_VALUE = {
"slow": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
"fast": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
None: pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_UNSPECIFIED,
}
94 changes: 94 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
from google.cloud import storage
from google.protobuf import json_format

Expand Down Expand Up @@ -621,6 +622,99 @@ def test_run_call_pipeline_service_create_with_timeout_not_explicitly_set(
timeout=None,
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize(
"failure_policy",
[
(
"slow",
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
),
(
"fast",
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
),
],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_failure_policy(
sararob marked this conversation as resolved.
Show resolved Hide resolved
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
job_spec,
mock_load_yaml_and_json,
failure_policy,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
enable_caching=True,
failure_policy=failure_policy[0],
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
sync=sync,
create_request_timeout=None,
)

if not sync:
job.wait()

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
"failurePolicy": failure_policy[1],
}
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

job_spec = yaml.safe_load(job_spec)
pipeline_spec = job_spec.get("pipelineSpec") or job_spec

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.1.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
timeout=None,
)

mock_pipeline_service_get.assert_called_with(
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec",
[
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.api_core import client_options, gapic_v1
from google.cloud import aiplatform
from google.cloud.aiplatform import compat, utils
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
from google.cloud.aiplatform.utils import pipeline_utils, tensorboard_utils, yaml_utils
from google.cloud.aiplatform_v1.services.model_service import (
client as model_service_client_v1,
Expand Down Expand Up @@ -454,7 +455,22 @@ def test_pipeline_utils_runtime_config_builder_with_no_op_updates(self):
expected_runtime_config = self.SAMPLE_JOB_SPEC["runtimeConfig"]
assert expected_runtime_config == actual_runtime_config

def test_pipeline_utils_runtime_config_builder_with_merge_updates(self):
@pytest.mark.parametrize(
"failure_policy",
[
(
"slow",
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
),
(
"fast",
pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
),
],
)
def test_pipeline_utils_runtime_config_builder_with_merge_updates(
self, failure_policy
):
my_builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
self.SAMPLE_JOB_SPEC
)
Expand All @@ -468,6 +484,7 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(self):
"bool_param": True,
}
)
my_builder.update_failure_policy(failure_policy[0])
actual_runtime_config = my_builder.build()

expected_runtime_config = {
Expand All @@ -481,6 +498,7 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(self):
"list_param": {"stringValue": "[1, 2, 3]"},
"bool_param": {"stringValue": "true"},
},
"failurePolicy": failure_policy[1],
}
assert expected_runtime_config == actual_runtime_config

Expand Down