From 03995777864fe6c14b58963d87780fb9ea0ad719 Mon Sep 17 00:00:00 2001 From: jaycee-li Date: Wed, 8 Jun 2022 14:07:20 -0700 Subject: [PATCH 1/3] feat: add update() method and system test --- google/cloud/aiplatform/datasets/dataset.py | 67 ++++++++++++++++++++- tests/system/aiplatform/test_dataset.py | 28 ++++++++- 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 97651adefb..9bb0921fd2 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2020 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ io as gca_io, ) from google.cloud.aiplatform.datasets import _datasources +from google.protobuf import field_mask_pb2 _LOGGER = base.Logger(__name__) @@ -597,8 +598,68 @@ def export_data(self, output_dir: str) -> Sequence[str]: return export_data_response.exported_files - def update(self): - raise NotImplementedError("Update dataset has not been implemented yet") + def update( + self, + display_name: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + update_request_timeout: Optional[float] = None, + ) -> "_Dataset": + """Update the dataset. + Updatable fields: + - ``display_name`` + - ``description`` + - ``labels`` + + Args: + display_name (str): + Optional. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + labels (Dict[str, str]): + Optional. Labels with user-defined metadata to organize your Tensorboards. + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, numeric + characters, underscores and dashes. International characters are allowed. + No more than 64 user labels can be associated with one Tensorboard + (System labels are excluded). + See https://goo.gl/xmQnxf for more information and examples of labels. + System reserved label keys are prefixed with "aiplatform.googleapis.com/" + and are immutable. + description (str): + Optional. The description of the Dataset. + update_request_timeout (float): + Optional. The timeout for the update request in seconds. + + Returns: + dataset (Dataset): + Updated dataset. + """ + + update_mask = field_mask_pb2.FieldMask() + if display_name: + update_mask.paths.append("display_name") + + if labels: + update_mask.paths.append("labels") + + if description: + update_mask.paths.append("description") + + update_dataset = gca_dataset.Dataset( + name=self.resource_name, + display_name=display_name, + description=description, + labels=labels, + ) + + self.api_client.update_dataset( + dataset=update_dataset, + update_mask=update_mask, + timeout=update_request_timeout, + ) + + return self.__class__(dataset_name=self.resource_name) @classmethod def list( diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py index 7cd3c0416c..bf8d760381 100644 --- a/tests/system/aiplatform/test_dataset.py +++ b/tests/system/aiplatform/test_dataset.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2020 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,6 +50,8 @@ "6203215905493614592" # permanent_text_entity_extraction_dataset ) _TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset" +_TEST_DATASET_LABELS = {"test": "labels"} +_TEST_DATASET_DESCRIPTION = "test description" _TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv" _TEST_FORECASTING_BQ_SOURCE = ( "bq://ucaip-sample-tests:ucaip_test_us_central1.2020_sales_train" @@ -350,3 +352,27 @@ def test_export_data(self, storage_client, staging_bucket): blob = bucket.get_blob(prefix) assert blob # Verify the returned GCS export path exists + + def test_update_dataset(self): + """Create a new dataset and use update() method to change its display_name, labels, and description. + Then confirm these fields of the dataset was successfully modifed.""" + + try: + dataset = aiplatform.ImageDataset.create() + labels = dataset.labels + + dataset = dataset.update( + display_name=_TEST_DATASET_DISPLAY_NAME, + labels=_TEST_DATASET_LABELS, + description=_TEST_DATASET_DESCRIPTION, + update_request_timeout=None, + ) + labels.update(_TEST_DATASET_LABELS) + + assert dataset.display_name == _TEST_DATASET_DISPLAY_NAME + assert dataset.labels == labels + assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION + + finally: + if dataset is not None: + dataset.delete() From d2de2e05c3605eeeb8852c022df31051d960436a Mon Sep 17 00:00:00 2001 From: jaycee-li Date: Wed, 8 Jun 2022 15:20:03 -0700 Subject: [PATCH 2/3] fix: fix and add unit test --- google/cloud/aiplatform/datasets/dataset.py | 5 ++- tests/unit/aiplatform/test_datasets.py | 46 +++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 9bb0921fd2..508932779b 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -600,6 +600,7 @@ def export_data(self, output_dir: str) -> Sequence[str]: def update( self, + *, display_name: Optional[str] = None, labels: Optional[Dict[str, str]] = None, description: Optional[str] = None, @@ -653,13 +654,13 @@ def update( labels=labels, ) - self.api_client.update_dataset( + self._gca_resource = self.api_client.update_dataset( dataset=update_dataset, update_mask=update_mask, timeout=update_request_timeout, ) - return self.__class__(dataset_name=self.resource_name) + return self @classmethod def list( diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 0624264e4c..2a912d47dd 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -36,6 +36,7 @@ from google.cloud.aiplatform import schema from google.cloud import bigquery from google.cloud import storage +from google.protobuf import field_mask_pb2 from google.cloud.aiplatform.compat.services import dataset_service_client @@ -59,6 +60,7 @@ _TEST_ID = "1028944691210842416" _TEST_DISPLAY_NAME = "my_dataset_1234" _TEST_DATA_LABEL_ITEMS = None +_TEST_DESCRIPTION = "test description" _TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}" _TEST_ALT_NAME = ( @@ -425,6 +427,20 @@ def export_data_mock(): yield export_data_mock +@pytest.fixture +def update_dataset_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "update_dataset" + ) as update_dataset_mock: + update_dataset_mock.return_value = gca_dataset.Dataset( + name=_TEST_NAME, + display_name=f"update_{_TEST_DISPLAY_NAME}", + labels=_TEST_LABELS, + description=_TEST_DESCRIPTION, + ) + yield update_dataset_mock + + @pytest.fixture def list_datasets_mock(): with patch.object( @@ -996,6 +1012,36 @@ def test_delete_dataset(self, delete_dataset_mock, sync): delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name) + @pytest.mark.usefixtures("get_dataset_mock") + def test_update_dataset(self, update_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) + + my_dataset = my_dataset.update( + display_name=f"update_{_TEST_DISPLAY_NAME}", + labels=_TEST_LABELS, + description=_TEST_DESCRIPTION, + update_request_timeout=None, + ) + + expected_dataset = gca_dataset.Dataset( + name=_TEST_NAME, + display_name=f"update_{_TEST_DISPLAY_NAME}", + labels=_TEST_LABELS, + description=_TEST_DESCRIPTION, + ) + + expected_mask = field_mask_pb2.FieldMask( + paths=["display_name", "labels", "description"] + ) + + update_dataset_mock.assert_called_once_with( + dataset=expected_dataset, + update_mask=expected_mask, + timeout=None, + ) + @pytest.mark.usefixtures("google_auth_mock") class TestImageDataset: From fcee028cab2c19b387a7bc48ed6a7bb2e8f2203a Mon Sep 17 00:00:00 2001 From: jaycee-li Date: Thu, 9 Jun 2022 10:59:24 -0700 Subject: [PATCH 3/3] remove superfluous line in system test --- tests/system/aiplatform/test_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py index bf8d760381..40f2e87c46 100644 --- a/tests/system/aiplatform/test_dataset.py +++ b/tests/system/aiplatform/test_dataset.py @@ -374,5 +374,4 @@ def test_update_dataset(self): assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION finally: - if dataset is not None: - dataset.delete() + dataset.delete()