Skip to content

Commit

Permalink
feat: support user provided api endpoint.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583223550
  • Loading branch information
sasha-gitg authored and copybara-github committed Nov 17, 2023
1 parent 8562368 commit 92f2b4e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 25 deletions.
52 changes: 35 additions & 17 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(self):
self._encryption_spec_key_name = None
self._network = None
self._service_account = None
self._api_endpoint = None

def init(
self,
Expand All @@ -119,6 +120,7 @@ def init(
encryption_spec_key_name: Optional[str] = None,
network: Optional[str] = None,
service_account: Optional[str] = None,
api_endpoint: Optional[str] = None,
):
"""Updates common initialization parameters with provided options.
Expand Down Expand Up @@ -174,11 +176,17 @@ def init(
PipelineJob, HyperparameterTuningJob, CustomTrainingJob,
CustomPythonPackageTrainingJob, CustomContainerTrainingJob,
ModelEvaluationJob.
api_endpoint (str):
Optional. The desired API endpoint,
e.g., us-central1-aiplatform.googleapis.com
Raises:
ValueError:
If experiment_description is provided but experiment is not.
"""

if api_endpoint is not None:
self._api_endpoint = api_endpoint

if experiment_description and experiment is None:
raise ValueError(
"Experiment needs to be set in `init` in order to add experiment descriptions."
Expand Down Expand Up @@ -252,6 +260,11 @@ def get_encryption_spec(
)
return encryption_spec

@property
def api_endpoint(self) -> Optional[str]:
"""Default API endpoint, if provided."""
return self._api_endpoint

@property
def project(self) -> str:
"""Default project."""
Expand Down Expand Up @@ -351,27 +364,32 @@ def get_client_options(
{ "api_endpoint": "us-central1-aiplatform.googleapis.com" } or
{ "api_endpoint": "asia-east1-aiplatform.googleapis.com" }
"""
if not (self.location or location_override):
raise ValueError(
"No location found. Provide or initialize SDK with a location."
)

region = location_override or self.location
region = region.lower()
api_endpoint = self.api_endpoint

utils.validate_region(region)
if api_endpoint is None:
if not (self.location or location_override):
raise ValueError(
"No location found. Provide or initialize SDK with a location."
)

service_base_path = api_base_path_override or (
constants.PREDICTION_API_BASE_PATH
if prediction_client
else constants.API_BASE_PATH
)
region = location_override or self.location
region = region.lower()

utils.validate_region(region)

service_base_path = api_base_path_override or (
constants.PREDICTION_API_BASE_PATH
if prediction_client
else constants.API_BASE_PATH
)

api_endpoint = (
f"{region}-{service_base_path}"
if not api_path_override
else api_path_override
)

api_endpoint = (
f"{region}-{service_base_path}"
if not api_path_override
else api_path_override
)
return client_options.ClientOptions(api_endpoint=api_endpoint)

def common_location_path(
Expand Down
50 changes: 42 additions & 8 deletions tests/unit/aiplatform/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@

import importlib
import os
import pytest
from typing import Optional
from unittest import mock
from unittest.mock import patch

import pytest

import google.auth
from google.auth import credentials

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.metadata.metadata import _experiment_tracker
from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import resource_manager_utils

from google.cloud.aiplatform.compat.services import (
model_service_client,
)
Expand Down Expand Up @@ -307,30 +307,64 @@ def test_create_client_appended_user_agent(self):
assert " " + appended_user_agent[0] in user_agent
assert " " + appended_user_agent[1] in user_agent

def test_set_api_endpoint(self):
initializer.global_config.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
api_endpoint="test.googleapis.com",
)

assert initializer.global_config.api_endpoint == "test.googleapis.com"

def test_not_set_api_endpoint(self):
initializer.global_config.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

assert initializer.global_config.api_endpoint is None

@pytest.mark.parametrize(
"init_location, location_override, expected_endpoint",
"init_location, location_override, api_endpoint, expected_endpoint",
[
("us-central1", None, "us-central1-aiplatform.googleapis.com"),
("us-central1", None, None, "us-central1-aiplatform.googleapis.com"),
(
"us-central1",
"europe-west4",
None,
"europe-west4-aiplatform.googleapis.com",
),
("asia-east1", None, "asia-east1-aiplatform.googleapis.com"),
("asia-east1", None, None, "asia-east1-aiplatform.googleapis.com"),
(
"asia-southeast1",
"australia-southeast1",
None,
"australia-southeast1-aiplatform.googleapis.com",
),
(
"asia-east1",
None,
"us-central1-aiplatform.googleapis.com",
"us-central1-aiplatform.googleapis.com",
),
(
"us-central1",
None,
"test.aiplatform.googleapis.com",
"test.aiplatform.googleapis.com",
),
],
)
def test_get_client_options(
self,
init_location: str,
location_override: str,
location_override: Optional[str],
api_endpoint: Optional[str],
expected_endpoint: str,
):
initializer.global_config.init(location=init_location)
initializer.global_config.init(
location=init_location, api_endpoint=api_endpoint
)

assert (
initializer.global_config.get_client_options(
Expand Down

0 comments on commit 92f2b4e

Please sign in to comment.