Skip to content

Commit

Permalink
Fix validating parent_model parameter in UploadModelOperator (#43473
Browse files Browse the repository at this point in the history
)

Co-authored-by: Ulada Zakharava <vlada_zakharava@epam.com>
  • Loading branch information
VladaZakharova and Ulada Zakharava authored Oct 29, 2024
1 parent eda6a8f commit b344cc1
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def upload_model(
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param model: Required. The Model to create.
:param parent_model: The name of the parent model to create a new version under.
:param parent_model: The ID of the parent model to create a new version under.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
Expand All @@ -233,7 +233,7 @@ def upload_model(
}

if parent_model:
request["parent_model"] = parent_model
request["parent_model"] = client.model_path(project_id, region, parent_model)

result = client.upload_model(
request=request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ class UploadModelOperator(GoogleCloudBaseOperator):
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param model: Required. The Model to create.
:param parent_model: The name of the parent model to create a new version under.
:param model: Required. The Model to create. Creating model with the name that already
exists leads to creating new version of existing model.
:param parent_model: The ID of the parent model to create a new version under.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
Expand All @@ -377,7 +378,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
"""

template_fields = ("region", "project_id", "model", "impersonation_chain")
template_fields = ("region", "project_id", "model", "parent_model", "impersonation_chain")
operator_extra_links = (VertexAIModelLink(),)

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TEST_REGION: str = "test-region"
TEST_PROJECT_ID: str = "test-project-id"
TEST_MODEL = None
TEST_PARENT_MODEL = "test-parent-model"
TEST_PARENT_MODEL = "projects/test-project-id/locations/test-region/models/test-parent-model"
TEST_MODEL_NAME: str = "test_model_name"
TEST_OUTPUT_CONFIG: dict = {}

Expand Down Expand Up @@ -148,7 +148,7 @@ def test_upload_model_with_parent_model(self, mock_client) -> None:
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
model=TEST_MODEL,
parent_model=TEST_PARENT_MODEL,
parent_model=mock_client.return_value.model_path.return_value,
),
metadata=(),
retry=DEFAULT,
Expand Down Expand Up @@ -352,7 +352,7 @@ def test_upload_model_with_parent_model(self, mock_client) -> None:
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
model=TEST_MODEL,
parent_model=TEST_PARENT_MODEL,
parent_model=mock_client.return_value.model_path.return_value,
),
metadata=(),
retry=DEFAULT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@
"health_route": "",
},
}
MODEL_OBJ_V2 = {
"display_name": f"model-{ENV_ID}-v2",
"artifact_uri": "{{ti.xcom_pull('custom_task')['artifactUri']}}",
"container_spec": {
"image_uri": MODEL_SERVING_CONTAINER_URI,
"command": [],
"args": [],
"env": [],
"ports": [],
"predict_route": "",
"health_route": "",
},
}


with DAG(
Expand Down Expand Up @@ -229,13 +242,14 @@
project_id=PROJECT_ID,
model=MODEL_OBJ,
)
upload_model_v1 = upload_model.output["model_id"]
# [END how_to_cloud_vertex_ai_upload_model_operator]
upload_model_with_parent_model = UploadModelOperator(
task_id="upload_model_with_parent_model",
region=REGION,
project_id=PROJECT_ID,
model=MODEL_OBJ,
parent_model=MODEL_DISPLAY_NAME,
model=MODEL_OBJ_V2,
parent_model=upload_model_v1,
)

# [START how_to_cloud_vertex_ai_export_model_operator]
Expand Down

0 comments on commit b344cc1

Please sign in to comment.