Skip to content

Commit

Permalink
feat: Add Metadata SDK support and samples for get method (#1516)
Browse files Browse the repository at this point in the history
* feat: Add get() method to Metadata Resource base class.

* Add unit tests for artifact, execution, and context

* Add samples for get execution and get artifact

* fix lint issues

* Fix unit tests.
  • Loading branch information
SinaChavoshi committed Jul 21, 2022
1 parent f93d19c commit d442248
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 7 deletions.
44 changes: 44 additions & 0 deletions google/cloud/aiplatform/metadata/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,50 @@ def get_or_create(
)
return resource

@classmethod
def get(
cls,
resource_id: str,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_Resource":
"""Retrieves a Metadata resource.
Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to retrieve or create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to retrieve or create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to retrieve or create this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource or None if no resouce was found.
"""
resource = cls._get(
resource_name=resource_id,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
return resource

def sync_resource(self):
"""Syncs local resource with the resource in metadata store."""
self._gca_resource = getattr(self.api_client, self._getter_method)(
Expand Down
14 changes: 14 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,13 @@ def mock_get_execution(mock_execution):
yield mock_get_execution


@pytest.fixture
def mock_execution_get(mock_execution):
with patch.object(aiplatform.Execution, "get") as mock_execution_get:
mock_execution_get.return_value = mock_execution
yield mock_execution_get


@pytest.fixture
def mock_create_execution(mock_execution):
with patch.object(aiplatform.Execution, "create") as mock_create_execution:
Expand All @@ -590,6 +597,13 @@ def mock_get_artifact(mock_artifact):
yield mock_get_artifact


@pytest.fixture
def mock_artifact_get(mock_artifact):
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
mock_artifact_get.return_value = mock_artifact
yield mock_artifact_get


@pytest.fixture
def mock_pipeline_job_create(mock_pipeline_job):
with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

# [START aiplatform_sdk_get_artifact_sample]
def get_artifact_sample(
uri: str,
artifact_id: str,
project: str,
location: str,
):
artifact = aiplatform.Artifact.get_with_uri(
uri=uri, project=project, location=location
artifact = aiplatform.Artifact.get(
resource_id=artifact_id, project=project, location=location
)

return artifact
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import test_constants


def test_get_artifact_sample(mock_artifact, mock_get_with_uri):
def test_get_artifact_sample(mock_artifact, mock_artifact_get):
artifact = get_artifact_sample.get_artifact_sample(
uri=test_constants.MODEL_ARTIFACT_URI,
artifact_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_get_with_uri.assert_called_with(
uri=test_constants.MODEL_ARTIFACT_URI,
mock_artifact_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import aiplatform


# [START aiplatform_sdk_get_artifact_with_uri_sample]
def get_artifact_with_uri_sample(
uri: str,
project: str,
location: str,
):
artifact = aiplatform.Artifact.get_with_uri(
uri=uri, project=project, location=location
)

return artifact


# [END aiplatform_sdk_get_artifact_with_uri_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_artifact_with_uri_sample

import test_constants


def test_get_artifact_with_uri_sample(mock_artifact, mock_get_with_uri):
artifact = get_artifact_with_uri_sample.get_artifact_with_uri_sample(
uri=test_constants.MODEL_ARTIFACT_URI,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_get_with_uri.assert_called_with(
uri=test_constants.MODEL_ARTIFACT_URI,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

assert artifact is mock_artifact
31 changes: 31 additions & 0 deletions samples/model-builder/experiment_tracking/get_execution_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import aiplatform


# [START aiplatform_sdk_get_execution_sample]
def get_execution_sample(
execution_id: str,
project: str,
location: str,
):
execution = aiplatform.Execution.get(
resource_id=execution_id, project=project, location=location
)

return execution


# [END aiplatform_sdk_get_execution_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_artifact_sample

import test_constants


def test_get_artifact_sample(mock_artifact, mock_artifact_get):
artifact = get_artifact_sample.get_artifact_sample(
artifact_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_artifact_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

assert artifact is mock_artifact
66 changes: 66 additions & 0 deletions tests/unit/aiplatform/test_metadata_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,28 @@ def test_get_or_create_context(
expected_context.name = _TEST_CONTEXT_NAME
assert my_context._gca_resource == expected_context

def test_get_context(self, get_context_mock):
aiplatform.init(project=_TEST_PROJECT)

my_context = context.Context.get(
resource_id=_TEST_CONTEXT_ID,
metadata_store_id=_TEST_METADATA_STORE,
)

expected_context = GapicContext(
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
)
get_context_mock.assert_called_once_with(
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
)

expected_context.name = _TEST_CONTEXT_NAME
assert my_context._gca_resource == expected_context

@pytest.mark.usefixtures("get_context_mock")
@pytest.mark.usefixtures("create_context_mock")
def test_update_context(self, update_context_mock):
Expand Down Expand Up @@ -633,6 +655,28 @@ def test_get_or_create_execution(
expected_execution.name = _TEST_EXECUTION_NAME
assert my_execution._gca_resource == expected_execution

def test_get_execution(self, get_execution_mock):
aiplatform.init(project=_TEST_PROJECT)

my_execution = execution.Execution.get(
resource_id=_TEST_EXECUTION_ID,
metadata_store_id=_TEST_METADATA_STORE,
)

expected_execution = GapicExecution(
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
)
get_execution_mock.assert_called_once_with(
name=_TEST_EXECUTION_NAME, retry=base._DEFAULT_RETRY
)

expected_execution.name = _TEST_EXECUTION_NAME
assert my_execution._gca_resource == expected_execution

@pytest.mark.usefixtures("get_execution_mock")
@pytest.mark.usefixtures("create_execution_mock")
def test_update_execution(self, update_execution_mock):
Expand Down Expand Up @@ -883,6 +927,28 @@ def test_get_or_create_artifact(
expected_artifact.name = _TEST_ARTIFACT_NAME
assert my_artifact._gca_resource == expected_artifact

def test_get_artifact(self, get_artifact_mock):
aiplatform.init(project=_TEST_PROJECT)

my_artifact = artifact.Artifact.get(
resource_id=_TEST_ARTIFACT_ID,
metadata_store_id=_TEST_METADATA_STORE,
)

expected_artifact = GapicArtifact(
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
)
get_artifact_mock.assert_called_once_with(
name=_TEST_ARTIFACT_NAME, retry=base._DEFAULT_RETRY
)

expected_artifact.name = _TEST_ARTIFACT_NAME
assert my_artifact._gca_resource == expected_artifact

@pytest.mark.usefixtures("get_artifact_mock")
@pytest.mark.usefixtures("create_artifact_mock")
def test_update_artifact(self, update_artifact_mock):
Expand Down

0 comments on commit d442248

Please sign in to comment.