Skip to content

Commit

Permalink
feat: add support for return public endpoint dns name in matching engine
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 525507137
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Apr 19, 2023
1 parent 4b0722c commit 1b5ae44
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ def _create(

return index_obj

@property
def public_endpoint_domain_name(self) -> Optional[str]:
"""Public endpoint DNS name."""
self._assert_gca_resource_is_available()
return self._gca_resource.public_endpoint_domain_name

def update(
self,
display_name: str,
Expand Down
65 changes: 64 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@

# index_endpoint
_TEST_INDEX_ENDPOINT_ID = "index_endpoint_id"
_TEST_INDEX_ENDPOINT_PUBLIC_DNS = (
"1114627793.us-central1-249381615684.vdb.vertexai.goog"
)
_TEST_INDEX_ENDPOINT_NAME = f"{_TEST_PARENT}/indexEndpoints/{_TEST_INDEX_ENDPOINT_ID}"
_TEST_INDEX_ENDPOINT_DISPLAY_NAME = "index_endpoint_display_name"
_TEST_INDEX_ENDPOINT_DESCRIPTION = "index_endpoint_description"
Expand Down Expand Up @@ -308,6 +311,57 @@ def get_index_endpoint_mock():
yield get_index_endpoint_mock


@pytest.fixture
def get_index_public_endpoint_mock():
with patch.object(
index_endpoint_service_client.IndexEndpointServiceClient, "get_index_endpoint"
) as get_index_public_endpoint_mock:
index_endpoint = gca_index_endpoint.IndexEndpoint(
name=_TEST_INDEX_ENDPOINT_NAME,
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
public_endpoint_domain_name=_TEST_INDEX_ENDPOINT_PUBLIC_DNS,
)
index_endpoint.deployed_indexes = [
gca_index_endpoint.DeployedIndex(
id=_TEST_DEPLOYED_INDEX_ID,
index=_TEST_INDEX_NAME,
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING,
deployment_group=_TEST_DEPLOYMENT_GROUP,
automatic_resources={
"min_replica_count": _TEST_MIN_REPLICA_COUNT,
"max_replica_count": _TEST_MAX_REPLICA_COUNT,
},
deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig(
auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider(
audiences=_TEST_AUTH_CONFIG_AUDIENCES,
allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS,
)
),
),
gca_index_endpoint.DeployedIndex(
id=f"{_TEST_DEPLOYED_INDEX_ID}_2",
index=f"{_TEST_INDEX_NAME}_2",
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING,
deployment_group=_TEST_DEPLOYMENT_GROUP,
automatic_resources={
"min_replica_count": _TEST_MIN_REPLICA_COUNT,
"max_replica_count": _TEST_MAX_REPLICA_COUNT,
},
deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig(
auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider(
audiences=_TEST_AUTH_CONFIG_AUDIENCES,
allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS,
)
),
),
]
get_index_public_endpoint_mock.return_value = index_endpoint
yield get_index_public_endpoint_mock


@pytest.fixture
def deploy_index_mock():
with patch.object(
Expand Down Expand Up @@ -556,7 +610,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_endpoint_mock")
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_create_index_endpoint_with_public_endpoint_enabled(
self, create_index_endpoint_mock
):
Expand All @@ -569,6 +623,10 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
labels=_TEST_LABELS,
)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

expected = gca_index_endpoint.IndexEndpoint(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
Expand All @@ -582,6 +640,11 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
metadata=_TEST_REQUEST_METADATA,
)

assert (
my_index_endpoint.public_endpoint_domain_name
== _TEST_INDEX_ENDPOINT_PUBLIC_DNS
)

def test_create_index_endpoint_missing_argument_throw_error(
self, create_index_endpoint_mock
):
Expand Down

0 comments on commit 1b5ae44

Please sign in to comment.