Skip to content

Commit

Permalink
feat: support ip_type as str (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Mar 7, 2024
1 parent 9782f6e commit b7b1d99
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

[flake8]
ignore = E203, E231, E266, E501, W503, ANN101, ANN401
ignore = E203, E231, E266, E501, W503, ANN101, ANN102, ANN401
exclude =
# Exclude generated code.
**/proto/**
Expand Down
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,11 @@ to your instance's private IP. To change this, such as connecting to AlloyDB
over a public IP address, set the `ip_type` keyword argument when initializing
a `Connector()` or when calling `connector.connect()`.

Possible values for `ip_type` are `IPTypes.PRIVATE` (default value), and
`IPTypes.PUBLIC`.
Possible values for `ip_type` are `"PRIVATE"` (default value), and `"PUBLIC"`.
Example:

```python
from google.cloud.alloydb.connector import Connector, IPTypes
from google.cloud.alloydb.connector import Connector

import sqlalchemy

Expand All @@ -401,7 +400,7 @@ def getconn():
user="my-user",
password="my-password",
db="my-db-name",
ip_type=IPTypes.PUBLIC, # use public IP
ip_type="PUBLIC", # use public IP
)

# create connection pool
Expand Down
14 changes: 10 additions & 4 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class AsyncConnector:
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE for private IP connections.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
"""

def __init__(
Expand All @@ -57,14 +57,17 @@ def __init__(
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: IPTypes = IPTypes.PRIVATE,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
) -> None:
self._instances: Dict[str, Instance] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
self._ip_type = ip_type
self._user_agent = user_agent
# initialize credentials
Expand Down Expand Up @@ -144,7 +147,10 @@ async def connect(
kwargs.pop("port", None)

# get connection info for AlloyDB instance
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
ip_type: str | IPTypes = kwargs.pop("ip_type", self._ip_type)
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
ip_address, context = await instance.connection_info(ip_type)

# callable to be used for auto IAM authn
Expand Down
14 changes: 10 additions & 4 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class Connector:
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE for private IP connections.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
"""

def __init__(
Expand All @@ -67,7 +67,7 @@ def __init__(
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: IPTypes = IPTypes.PRIVATE,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
) -> None:
# create event loop and start it in background thread
Expand All @@ -79,6 +79,9 @@ def __init__(
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
self._ip_type = ip_type
self._user_agent = user_agent
# initialize credentials
Expand Down Expand Up @@ -171,7 +174,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
kwargs.pop("port", None)

# get connection info for AlloyDB instance
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
ip_type: IPTypes | str = kwargs.pop("ip_type", self._ip_type)
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
ip_address, context = await instance.connection_info(ip_type)

# synchronous drivers are blocking and run using executor
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class IPTypes(Enum):
PUBLIC: str = "PUBLIC"
PRIVATE: str = "PRIVATE"

@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(
f"Incorrect value for ip_type, got '{value}'. Want one of: "
f"{', '.join([repr(m.value) for m in cls])}."
)


def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]:
# should take form "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_asyncpg_public_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import sqlalchemy.ext.asyncio

from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes


async def create_sqlalchemy_engine(
Expand Down Expand Up @@ -70,7 +69,7 @@ async def getconn() -> asyncpg.Connection:
user=user,
password=password,
db=db,
ip_type=IPTypes.PUBLIC,
ip_type="PUBLIC",
)
return conn

Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_pg8000_public_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import sqlalchemy

from google.cloud.alloydb.connector import Connector
from google.cloud.alloydb.connector import IPTypes


def create_sqlalchemy_engine(
Expand Down Expand Up @@ -70,7 +69,7 @@ def getconn() -> pg8000.dbapi.Connection:
user=user,
password=password,
db=db,
ip_type=IPTypes.PUBLIC,
ip_type="PUBLIC",
)
return conn

Expand Down
76 changes: 76 additions & 0 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
from typing import Union

from mock import patch
from mocks import FakeAlloyDBClient
Expand All @@ -21,6 +22,7 @@
import pytest

from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes

ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com"

Expand All @@ -40,6 +42,58 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None:
await connector.close()


@pytest.mark.parametrize(
"ip_type, expected",
[
(
"private",
IPTypes.PRIVATE,
),
(
"PRIVATE",
IPTypes.PRIVATE,
),
(
IPTypes.PRIVATE,
IPTypes.PRIVATE,
),
(
"public",
IPTypes.PUBLIC,
),
(
"PUBLIC",
IPTypes.PUBLIC,
),
(
IPTypes.PUBLIC,
IPTypes.PUBLIC,
),
],
)
async def test_AsyncConnector_init_ip_type(
ip_type: Union[str, IPTypes], expected: IPTypes, credentials: FakeCredentials
) -> None:
"""
Test to check whether the __init__ method of AsyncConnector
properly sets ip_type.
"""
connector = AsyncConnector(credentials=credentials, ip_type=ip_type)
assert connector._ip_type == expected
connector.close()


async def test_AsyncConnector_init_bad_ip_type(credentials: FakeCredentials) -> None:
"""Test that AsyncConnector errors due to bad ip_type str."""
bad_ip_type = "BAD-IP-TYPE"
with pytest.raises(ValueError) as exc_info:
AsyncConnector(ip_type=bad_ip_type, credentials=credentials)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)


@pytest.mark.asyncio
async def test_AsyncConnector_context_manager(
credentials: FakeCredentials,
Expand Down Expand Up @@ -202,3 +256,25 @@ def test_synchronous_init(credentials: FakeCredentials) -> None:
"""
connector = AsyncConnector(credentials)
assert connector._keys is None


async def test_async_connect_bad_ip_type(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""Test that AyncConnector.connect errors due to bad ip_type str."""
async with AsyncConnector(credentials=credentials) as connector:
connector._client = fake_client
bad_ip_type = "BAD-IP-TYPE"
with pytest.raises(ValueError) as exc_info:
await connector.connect(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
"asyncpg",
user="test-user",
password="test-password",
db="test-db",
ip_type=bad_ip_type,
)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)
76 changes: 76 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

import asyncio
from threading import Thread
from typing import Union

from mock import patch
from mocks import FakeAlloyDBClient
from mocks import FakeCredentials
import pytest

from google.cloud.alloydb.connector import Connector
from google.cloud.alloydb.connector import IPTypes


def test_Connector_init(credentials: FakeCredentials) -> None:
Expand All @@ -36,6 +38,58 @@ def test_Connector_init(credentials: FakeCredentials) -> None:
connector.close()


def test_Connector_init_bad_ip_type(credentials: FakeCredentials) -> None:
"""Test that Connector errors due to bad ip_type str."""
bad_ip_type = "BAD-IP-TYPE"
with pytest.raises(ValueError) as exc_info:
Connector(ip_type=bad_ip_type, credentials=credentials)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)


@pytest.mark.parametrize(
"ip_type, expected",
[
(
"private",
IPTypes.PRIVATE,
),
(
"PRIVATE",
IPTypes.PRIVATE,
),
(
IPTypes.PRIVATE,
IPTypes.PRIVATE,
),
(
"public",
IPTypes.PUBLIC,
),
(
"PUBLIC",
IPTypes.PUBLIC,
),
(
IPTypes.PUBLIC,
IPTypes.PUBLIC,
),
],
)
def test_Connector_init_ip_type(
ip_type: Union[str, IPTypes], expected: IPTypes, credentials: FakeCredentials
) -> None:
"""
Test to check whether the __init__ method of Connector
properly sets ip_type.
"""
connector = Connector(credentials=credentials, ip_type=ip_type)
assert connector._ip_type == expected
connector.close()


def test_Connector_context_manager(credentials: FakeCredentials) -> None:
"""
Test to check whether the __init__ method of Connector
Expand Down Expand Up @@ -84,6 +138,28 @@ def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -
assert connection is True


def test_connect_bad_ip_type(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""Test that Connector.connect errors due to bad ip_type str."""
with Connector(credentials=credentials) as connector:
connector._client = fake_client
bad_ip_type = "BAD-IP-TYPE"
with pytest.raises(ValueError) as exc_info:
connector.connect(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
"pg8000",
user="test-user",
password="test-password",
db="test-db",
ip_type=bad_ip_type,
)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)


def test_connect_unsupported_driver(credentials: FakeCredentials) -> None:
"""
Test that connector.connect errors with unsupported database driver.
Expand Down

0 comments on commit b7b1d99

Please sign in to comment.