Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dataset update #1416

Merged
merged 4 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions google/cloud/aiplatform/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -597,8 +598,69 @@ def export_data(self, output_dir: str) -> Sequence[str]:

return export_data_response.exported_files

def update(self):
raise NotImplementedError("Update dataset has not been implemented yet")
def update(
    self,
    *,
    display_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    description: Optional[str] = None,
    update_request_timeout: Optional[float] = None,
) -> "_Dataset":
    """Update the dataset.

    Updatable fields:
        -  ``display_name``
        -  ``description``
        -  ``labels``

    Args:
        display_name (str):
            Optional. The user-defined name of the Dataset.
            The name can be up to 128 characters long and can consist
            of any UTF-8 characters.
        labels (Dict[str, str]):
            Optional. Labels with user-defined metadata to organize your
            Datasets.
            Label keys and values can be no longer than 64 characters
            (Unicode codepoints), can only contain lowercase letters, numeric
            characters, underscores and dashes. International characters are allowed.
            No more than 64 user labels can be associated with one Dataset
            (System labels are excluded).
            See https://goo.gl/xmQnxf for more information and examples of labels.
            System reserved label keys are prefixed with "aiplatform.googleapis.com/"
            and are immutable.
        description (str):
            Optional. The description of the Dataset.
        update_request_timeout (float):
            Optional. The timeout for the update request in seconds.

    Returns:
        dataset (Dataset):
            Updated dataset.
    """
    # Build a field mask containing only the fields the caller supplied so
    # the service leaves the other fields untouched.
    # NOTE(review): truthiness checks mean "empty" values (e.g. labels={},
    # display_name="") are treated as not provided, so a field cannot be
    # cleared through this method — confirm this is intended.
    update_mask = field_mask_pb2.FieldMask()
    if display_name:
        update_mask.paths.append("display_name")

    if labels:
        update_mask.paths.append("labels")

    if description:
        update_mask.paths.append("description")

    update_dataset = gca_dataset.Dataset(
        name=self.resource_name,
        display_name=display_name,
        description=description,
        labels=labels,
    )

    # Cache the updated proto so subsequent reads reflect the new state.
    self._gca_resource = self.api_client.update_dataset(
        dataset=update_dataset,
        update_mask=update_mask,
        timeout=update_request_timeout,
    )

    return self

@classmethod
def list(
Expand Down
28 changes: 27 additions & 1 deletion tests/system/aiplatform/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,8 @@
"6203215905493614592" # permanent_text_entity_extraction_dataset
)
_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
_TEST_DATASET_LABELS = {"test": "labels"}
_TEST_DATASET_DESCRIPTION = "test description"
_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
_TEST_FORECASTING_BQ_SOURCE = (
"bq://ucaip-sample-tests:ucaip_test_us_central1.2020_sales_train"
Expand Down Expand Up @@ -350,3 +352,27 @@ def test_export_data(self, storage_client, staging_bucket):
blob = bucket.get_blob(prefix)

assert blob # Verify the returned GCS export path exists

def test_update_dataset(self):
    """Create a new dataset and use update() method to change its display_name, labels, and description.
    Then confirm these fields of the dataset were successfully modified."""

    # Start as None: if ImageDataset.create() raises, `dataset` would
    # otherwise never be bound and the `finally` block's name check would
    # itself raise NameError instead of skipping cleanup.
    dataset = None
    try:
        dataset = aiplatform.ImageDataset.create()
        labels = dataset.labels

        dataset = dataset.update(
            display_name=_TEST_DATASET_DISPLAY_NAME,
            labels=_TEST_DATASET_LABELS,
            description=_TEST_DATASET_DESCRIPTION,
            update_request_timeout=None,
        )
        # The updated resource keeps its original labels merged with the
        # new ones, so compare against the pre-update labels updated with
        # the test labels.
        labels.update(_TEST_DATASET_LABELS)

        assert dataset.display_name == _TEST_DATASET_DISPLAY_NAME
        assert dataset.labels == labels
        assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION

    finally:
        # Only clean up if create() succeeded.
        if dataset is not None:
            dataset.delete()
46 changes: 46 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from google.cloud.aiplatform import schema
from google.cloud import bigquery
from google.cloud import storage
from google.protobuf import field_mask_pb2

from google.cloud.aiplatform.compat.services import dataset_service_client

Expand All @@ -59,6 +60,7 @@
_TEST_ID = "1028944691210842416"
_TEST_DISPLAY_NAME = "my_dataset_1234"
_TEST_DATA_LABEL_ITEMS = None
_TEST_DESCRIPTION = "test description"

_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
_TEST_ALT_NAME = (
Expand Down Expand Up @@ -425,6 +427,20 @@ def export_data_mock():
yield export_data_mock


@pytest.fixture
def update_dataset_mock():
    """Patch DatasetServiceClient.update_dataset to return a canned Dataset."""
    canned_response = gca_dataset.Dataset(
        name=_TEST_NAME,
        display_name=f"update_{_TEST_DISPLAY_NAME}",
        labels=_TEST_LABELS,
        description=_TEST_DESCRIPTION,
    )
    with patch.object(
        dataset_service_client.DatasetServiceClient, "update_dataset"
    ) as update_dataset_mock:
        update_dataset_mock.return_value = canned_response
        yield update_dataset_mock


@pytest.fixture
def list_datasets_mock():
with patch.object(
Expand Down Expand Up @@ -996,6 +1012,36 @@ def test_delete_dataset(self, delete_dataset_mock, sync):

delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name)

@pytest.mark.usefixtures("get_dataset_mock")
def test_update_dataset(self, update_dataset_mock):
    """update() should issue exactly one update RPC with the new field
    values and a mask listing each updated path."""
    aiplatform.init(project=_TEST_PROJECT)

    updated_display_name = f"update_{_TEST_DISPLAY_NAME}"

    my_dataset = datasets._Dataset(dataset_name=_TEST_NAME)
    my_dataset = my_dataset.update(
        display_name=updated_display_name,
        labels=_TEST_LABELS,
        description=_TEST_DESCRIPTION,
        update_request_timeout=None,
    )

    update_dataset_mock.assert_called_once_with(
        dataset=gca_dataset.Dataset(
            name=_TEST_NAME,
            display_name=updated_display_name,
            labels=_TEST_LABELS,
            description=_TEST_DESCRIPTION,
        ),
        update_mask=field_mask_pb2.FieldMask(
            paths=["display_name", "labels", "description"]
        ),
        timeout=None,
    )


@pytest.mark.usefixtures("google_auth_mock")
class TestImageDataset:
Expand Down