Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Metadata SDK support and samples for get method #1516

Merged
merged 10 commits into from
Jul 21, 2022
44 changes: 44 additions & 0 deletions google/cloud/aiplatform/metadata/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,50 @@ def get_or_create(
)
return resource

@classmethod
def get(
cls,
resource_id: str,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_Resource":
"""Retrieves a Metadata resource.

Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to retrieve or create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to retrieve or create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to retrieve or create this resource. Overrides
credentials set in aiplatform.init.

Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource or None if no resouce was found.

"""
resource = cls._get(
resource_name=resource_id,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
return resource

def sync_resource(self):
"""Syncs local resource with the resource in metadata store."""
self._gca_resource = getattr(self.api_client, self._getter_method)(
Expand Down
14 changes: 14 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,13 @@ def mock_get_execution(mock_execution):
yield mock_get_execution


@pytest.fixture
def mock_execution_get(mock_execution):
with patch.object(aiplatform.Execution, "get") as mock_execution_get:
mock_execution_get.return_value = mock_execution
yield mock_execution_get


@pytest.fixture
def mock_create_execution(mock_execution):
with patch.object(aiplatform.Execution, "create") as mock_create_execution:
Expand All @@ -590,6 +597,13 @@ def mock_get_artifact(mock_artifact):
yield mock_get_artifact


@pytest.fixture
def mock_artifact_get(mock_artifact):
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
mock_artifact_get.return_value = mock_artifact
yield mock_artifact_get


@pytest.fixture
def mock_pipeline_job_create(mock_pipeline_job):
with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

# [START aiplatform_sdk_get_artifact_sample]
def get_artifact_sample(
uri: str,
artifact_id: str,
project: str,
location: str,
):
artifact = aiplatform.Artifact.get_with_uri(
uri=uri, project=project, location=location
artifact = aiplatform.Artifact.get(
resource_id=artifact_id, project=project, location=location
)

return artifact
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import test_constants


def test_get_artifact_sample(mock_artifact, mock_get_with_uri):
def test_get_artifact_sample(mock_artifact, mock_artifact_get):
artifact = get_artifact_sample.get_artifact_sample(
uri=test_constants.MODEL_ARTIFACT_URI,
artifact_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_get_with_uri.assert_called_with(
uri=test_constants.MODEL_ARTIFACT_URI,
mock_artifact_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import aiplatform


# [START aiplatform_sdk_get_artifact_with_uri_sample]
def get_artifact_with_uri_sample(
uri: str,
project: str,
location: str,
):
artifact = aiplatform.Artifact.get_with_uri(
uri=uri, project=project, location=location
)

return artifact


# [END aiplatform_sdk_get_artifact_with_uri_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_artifact_with_uri_sample

import test_constants


def test_get_artifact_with_uri_sample(mock_artifact, mock_get_with_uri):
artifact = get_artifact_with_uri_sample.get_artifact_with_uri_sample(
uri=test_constants.MODEL_ARTIFACT_URI,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_get_with_uri.assert_called_with(
uri=test_constants.MODEL_ARTIFACT_URI,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

assert artifact is mock_artifact
31 changes: 31 additions & 0 deletions samples/model-builder/experiment_tracking/get_execution_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import aiplatform


# [START aiplatform_sdk_get_execution_sample]
def get_execution_sample(
execution_id: str,
project: str,
location: str,
):
execution = aiplatform.Execution.get(
resource_id=execution_id, project=project, location=location
)

return execution


# [END aiplatform_sdk_get_execution_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_artifact_sample

import test_constants


def test_get_artifact_sample(mock_artifact, mock_artifact_get):
artifact = get_artifact_sample.get_artifact_sample(
artifact_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_artifact_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

assert artifact is mock_artifact
66 changes: 66 additions & 0 deletions tests/unit/aiplatform/test_metadata_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,28 @@ def test_get_or_create_context(
expected_context.name = _TEST_CONTEXT_NAME
assert my_context._gca_resource == expected_context

def test_get_context(self, get_context_mock):
aiplatform.init(project=_TEST_PROJECT)

my_context = context.Context.get(
resource_id=_TEST_CONTEXT_ID,
metadata_store_id=_TEST_METADATA_STORE,
)

expected_context = GapicContext(
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
)
get_context_mock.assert_called_once_with(
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
)

expected_context.name = _TEST_CONTEXT_NAME
assert my_context._gca_resource == expected_context

@pytest.mark.usefixtures("get_context_mock")
@pytest.mark.usefixtures("create_context_mock")
def test_update_context(self, update_context_mock):
Expand Down Expand Up @@ -633,6 +655,28 @@ def test_get_or_create_execution(
expected_execution.name = _TEST_EXECUTION_NAME
assert my_execution._gca_resource == expected_execution

def test_get_execution(self, get_execution_mock):
aiplatform.init(project=_TEST_PROJECT)

my_execution = execution.Execution.get(
resource_id=_TEST_EXECUTION_ID,
metadata_store_id=_TEST_METADATA_STORE,
)

expected_execution = GapicExecution(
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
)
get_execution_mock.assert_called_once_with(
name=_TEST_EXECUTION_NAME, retry=base._DEFAULT_RETRY
)

expected_execution.name = _TEST_EXECUTION_NAME
assert my_execution._gca_resource == expected_execution

@pytest.mark.usefixtures("get_execution_mock")
@pytest.mark.usefixtures("create_execution_mock")
def test_update_execution(self, update_execution_mock):
Expand Down Expand Up @@ -883,6 +927,28 @@ def test_get_or_create_artifact(
expected_artifact.name = _TEST_ARTIFACT_NAME
assert my_artifact._gca_resource == expected_artifact

def test_get_artifact(self, get_artifact_mock):
aiplatform.init(project=_TEST_PROJECT)

my_artifact = artifact.Artifact.get(
resource_id=_TEST_ARTIFACT_ID,
metadata_store_id=_TEST_METADATA_STORE,
)

expected_artifact = GapicArtifact(
schema_title=_TEST_SCHEMA_TITLE,
schema_version=_TEST_SCHEMA_VERSION,
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
metadata=_TEST_METADATA,
)
get_artifact_mock.assert_called_once_with(
name=_TEST_ARTIFACT_NAME, retry=base._DEFAULT_RETRY
)

expected_artifact.name = _TEST_ARTIFACT_NAME
assert my_artifact._gca_resource == expected_artifact

@pytest.mark.usefixtures("get_artifact_mock")
@pytest.mark.usefixtures("create_artifact_mock")
def test_update_artifact(self, update_artifact_mock):
Expand Down