Skip to content

Commit

Permalink
feat: add enable_simple_view to PipelineJob.list() (#1614)
Browse files Browse the repository at this point in the history
* feat: add enable_simple_view to PipelineJob.list()

* updates to pipelinejob.list read_mask

* run linter

* update to read_mask

* add placeholder for read_mask to system tests

* unit test fix

* add system test for read_mask filter

* move read mask fields to constants file

* add read_mask docstrings

* remove class name check

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
  • Loading branch information
sararob and sasha-gitg authored Sep 14, 2022
1 parent a3cc5a3 commit 627fdf9
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 0 deletions.
25 changes: 25 additions & 0 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec
from google.cloud.aiplatform.constants import base as base_constants
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask

# This is the default retry callback to be used with get methods.
_DEFAULT_RETRY = retry.Retry()
Expand Down Expand Up @@ -1030,6 +1031,7 @@ def _list(
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
filter: Optional[str] = None,
order_by: Optional[str] = None,
read_mask: Optional[field_mask.FieldMask] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand All @@ -1052,6 +1054,14 @@ def _list(
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
read_mask (field_mask.FieldMask):
Optional. A FieldMask with a list of strings passed via `paths`
indicating which fields to return for each resource in the response.
For example, passing
field_mask.FieldMask(paths=["create_time", "update_time"])
as `read_mask` would result in each returned VertexAiResourceNoun
in the result list only having the "create_time" and
"update_time" attributes.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
Expand All @@ -1067,6 +1077,7 @@ def _list(
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""

resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
Expand All @@ -1083,6 +1094,10 @@ def _list(
),
}

# `read_mask` is only passed from PipelineJob.list() for now
if read_mask is not None:
list_request["read_mask"] = read_mask

if filter:
list_request["filter"] = filter

Expand All @@ -1105,6 +1120,7 @@ def _list_with_local_order(
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
filter: Optional[str] = None,
order_by: Optional[str] = None,
read_mask: Optional[field_mask.FieldMask] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand All @@ -1127,6 +1143,14 @@ def _list_with_local_order(
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
read_mask (field_mask.FieldMask):
Optional. A FieldMask with a list of strings passed via `paths`
indicating which fields to return for each resource in the response.
For example, passing
field_mask.FieldMask(paths=["create_time", "update_time"])
as `read_mask` would result in each returned VertexAiResourceNoun
in the result list only having the "create_time" and
"update_time" attributes.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
Expand All @@ -1145,6 +1169,7 @@ def _list_with_local_order(
cls_filter=cls_filter,
filter=filter,
order_by=None, # This method will handle the ordering locally
read_mask=read_mask,
project=project,
location=location,
credentials=credentials,
Expand Down
17 changes: 17 additions & 0 deletions google/cloud/aiplatform/constants/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,20 @@

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")

# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
# These paths are wrapped in a google.protobuf FieldMask and forwarded on the
# ListPipelineJobs request as `read_mask`; fields not listed here (e.g.
# `runtime_config`, `service_account`, `network`) are omitted from each
# returned resource. NOTE(review): path order is preserved as-is because the
# unit tests build an expected FieldMask from this same list, and repeated
# proto fields compare order-sensitively.
_READ_MASK_FIELDS = [
    "name",
    "state",
    "display_name",
    # Only the top-level pipeline info sub-field, not the full pipeline_spec.
    "pipeline_spec.pipeline_info",
    "create_time",
    "start_time",
    "end_time",
    "update_time",
    "labels",
    "template_uri",
    "template_metadata.version",
    # Only these two sub-fields of job_detail are retained.
    "job_detail.pipeline_run_context",
    "job_detail.pipeline_context",
]
24 changes: 24 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from google.cloud.aiplatform.utils import yaml_utils
from google.cloud.aiplatform.utils import pipeline_utils
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask

from google.cloud.aiplatform.compat.types import (
pipeline_job as gca_pipeline_job,
Expand All @@ -56,6 +57,8 @@
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS


def _get_current_time() -> datetime.datetime:
"""Gets the current timestamp."""
Expand Down Expand Up @@ -509,6 +512,7 @@ def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
enable_simple_view: Optional[bool] = False,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand All @@ -530,6 +534,17 @@ def list(
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
enable_simple_view (bool):
Optional. Whether to pass the `read_mask` parameter to the list call.
This will improve the performance of calling list(). However, the
returned PipelineJob list will not include all fields for each PipelineJob.
Setting this to True will exclude the following fields in your response:
`runtime_config`, `service_account`, `network`, and some subfields of
`pipeline_spec` and `job_detail`. The following fields will be included in
each PipelineJob resource in your response: `state`, `display_name`,
`pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`,
`update_time`, `labels`, `template_uri`, `template_metadata.version`,
`job_detail.pipeline_run_context`, `job_detail.pipeline_context`.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
Expand All @@ -544,9 +559,18 @@ def list(
List[PipelineJob] - A list of PipelineJob resource objects
"""

read_mask_fields = None

if enable_simple_view:
read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS)
_LOGGER.warn(
"By enabling simple view, the PipelineJob resources returned from this method will not contain all fields."
)

return cls._list_with_local_order(
filter=filter,
order_by=order_by,
read_mask=read_mask_fields,
project=project,
location=location,
credentials=credentials,
Expand Down
13 changes: 13 additions & 0 deletions tests/system/aiplatform/test_pipeline_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base

from google.protobuf.json_format import MessageToDict


@pytest.mark.usefixtures("tear_down_resources")
class TestPipelineJob(e2e_base.TestEndToEnd):
Expand Down Expand Up @@ -59,3 +61,14 @@ def training_pipeline(number_of_epochs: int = 10):
shared_state.setdefault("resources", []).append(job)

job.wait()

list_with_read_mask = aiplatform.PipelineJob.list(enable_simple_view=True)
list_without_read_mask = aiplatform.PipelineJob.list()

# enable_simple_view=True should apply the `read_mask` filter to limit PipelineJob fields returned
assert "serviceAccount" in MessageToDict(
list_without_read_mask[0].gca_resource._pb
)
assert "serviceAccount" not in MessageToDict(
list_with_read_mask[0].gca_resource._pb
)
57 changes: 57 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
from google.cloud.aiplatform_v1 import Context as GapicContext
from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore
from google.cloud.aiplatform.metadata import constants
Expand All @@ -37,6 +38,7 @@
from google.cloud.aiplatform.utils import gcs_utils
from google.cloud import storage
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask

from google.cloud.aiplatform.compat.services import (
pipeline_service_client,
Expand All @@ -62,6 +64,9 @@
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
# Expected `read_mask` on the ListPipelineJobs request when
# PipelineJob.list(enable_simple_view=True) is used. Built from the same
# constant the SDK itself uses so the test stays in sync with the
# implementation if the field list changes.
_TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask(
    paths=pipeline_constants._READ_MASK_FIELDS
)

_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
_TEST_PIPELINE_PARAMETER_VALUES = {
Expand Down Expand Up @@ -332,6 +337,17 @@ def mock_pipeline_service_list():
with mock.patch.object(
pipeline_service_client.PipelineServiceClient, "list_pipeline_jobs"
) as mock_list_pipeline_jobs:
mock_list_pipeline_jobs.return_value = [
make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
]
yield mock_list_pipeline_jobs


Expand Down Expand Up @@ -1354,6 +1370,47 @@ def test_list_pipeline_job(
request={"parent": _TEST_PARENT}
)

@pytest.mark.usefixtures(
    "mock_pipeline_service_create",
    "mock_pipeline_service_get",
    "mock_pipeline_bucket_exists",
)
@pytest.mark.parametrize(
    "job_spec",
    [
        _TEST_PIPELINE_SPEC_JSON,
        _TEST_PIPELINE_SPEC_YAML,
        _TEST_PIPELINE_JOB,
        _TEST_PIPELINE_SPEC_LEGACY_JSON,
        _TEST_PIPELINE_SPEC_LEGACY_YAML,
        _TEST_PIPELINE_JOB_LEGACY,
    ],
)
def test_list_pipeline_job_with_read_mask(
    self, mock_pipeline_service_list, mock_load_yaml_and_json
):
    """list(enable_simple_view=True) sends a read_mask on the list request."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
    )

    # Create and run a job so there is something for the service to list.
    job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
    )
    job.run()

    # Simple view must translate into a `read_mask` on the RPC request.
    job.list(enable_simple_view=True)

    expected_request = {
        "parent": _TEST_PARENT,
        "read_mask": _TEST_PIPELINE_JOB_LIST_READ_MASK,
    }
    mock_pipeline_service_list.assert_called_once_with(request=expected_request)

@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",
Expand Down

0 comments on commit 627fdf9

Please sign in to comment.