Skip to content

Commit

Permalink
feat: use non-blocking disk read/writes (#360)
Browse files Browse the repository at this point in the history
Python's standard library read and writes which are blocking I/O.

This PR switches to use aiofiles which is non-blocking approach to
read/write to disk.
  • Loading branch information
jackwotherspoon authored Aug 6, 2024
1 parent 318445f commit ba434e7
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 23 deletions.
4 changes: 3 additions & 1 deletion google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def get_authentication_token() -> str:
if enable_iam_auth:
kwargs["password"] = get_authentication_token
try:
return await connector(ip_address, conn_info.create_ssl_context(), **kwargs)
return await connector(
ip_address, await conn_info.create_ssl_context(), **kwargs
)
except Exception:
# we attempt a force refresh, then throw the error
await cache.force_refresh()
Expand Down
9 changes: 5 additions & 4 deletions google/cloud/alloydb/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from dataclasses import dataclass
import logging
import ssl
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, TYPE_CHECKING

from aiofiles.tempfile import TemporaryDirectory

from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
from google.cloud.alloydb.connector.utils import _write_to_file

Expand All @@ -45,7 +46,7 @@ class ConnectionInfo:
expiration: datetime.datetime
context: Optional[ssl.SSLContext] = None

def create_ssl_context(self) -> ssl.SSLContext:
async def create_ssl_context(self) -> ssl.SSLContext:
"""Constructs a SSL/TLS context for the given connection info.
Cache the SSL context to ensure we don't read from disk repeatedly when
Expand All @@ -66,8 +67,8 @@ def create_ssl_context(self) -> ssl.SSLContext:
# tmpdir and its contents are automatically deleted after the CA cert
# and cert chain are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
with TemporaryDirectory() as tmpdir:
ca_filename, cert_chain_filename, key_filename = _write_to_file(
async with TemporaryDirectory() as tmpdir:
ca_filename, cert_chain_filename, key_filename = await _write_to_file(
tmpdir, self.ca_cert, self.cert_chain, self.key
)
context.load_cert_chain(cert_chain_filename, keyfile=key_filename)
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
metadata_partial = partial(
self.metadata_exchange,
ip_address,
conn_info.create_ssl_context(),
await conn_info.create_ssl_context(),
enable_iam_auth,
driver,
)
Expand Down
15 changes: 8 additions & 7 deletions google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

from typing import List, Tuple

import aiofiles
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


def _write_to_file(
async def _write_to_file(
dir_path: str, ca_cert: str, cert_chain: List[str], key: rsa.RSAPrivateKey
) -> Tuple[str, str, str]:
"""
Expand All @@ -37,12 +38,12 @@ def _write_to_file(
encryption_algorithm=serialization.NoEncryption(),
)

with open(ca_filename, "w+") as ca_out:
ca_out.write(ca_cert)
with open(cert_chain_filename, "w+") as chain_out:
chain_out.write("".join(cert_chain))
with open(key_filename, "wb") as priv_out:
priv_out.write(key_bytes)
async with aiofiles.open(ca_filename, "w+") as ca_out:
await ca_out.write(ca_cert)
async with aiofiles.open(cert_chain_filename, "w+") as chain_out:
await chain_out.write("".join(cert_chain))
async with aiofiles.open(key_filename, "wb") as priv_out:
await priv_out.write(key_bytes)

return (ca_filename, cert_chain_filename, key_filename)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiofiles==24.1.0
aiohttp==3.9.5
cryptography==42.0.8
google-auth==2.32.0
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
"aiofiles",
"aiohttp",
"cryptography>=42.0.0",
"requests",
Expand Down
19 changes: 14 additions & 5 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import socket
import ssl
from tempfile import TemporaryDirectory
from threading import Thread
from typing import Generator

from aiofiles.tempfile import TemporaryDirectory
from mocks import FakeAlloyDBClient
from mocks import FakeCredentials
from mocks import FakeInstance
Expand All @@ -42,7 +43,7 @@ def fake_client(fake_instance: FakeInstance) -> FakeAlloyDBClient:
return FakeAlloyDBClient(fake_instance)


def start_proxy_server(instance: FakeInstance) -> None:
async def start_proxy_server(instance: FakeInstance) -> None:
"""Run local proxy server capable of performing metadata exchange"""
ip_address = "127.0.0.1"
port = 5433
Expand All @@ -55,8 +56,8 @@ def start_proxy_server(instance: FakeInstance) -> None:
# tmpdir and its contents are automatically deleted after the CA cert
# and cert chain are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
with TemporaryDirectory() as tmpdir:
_, cert_chain_filename, key_filename = _write_to_file(
async with TemporaryDirectory() as tmpdir:
_, cert_chain_filename, key_filename = await _write_to_file(
tmpdir, server, [server, root], instance.server_key
)
context.load_cert_chain(cert_chain_filename, key_filename)
Expand All @@ -76,7 +77,15 @@ def start_proxy_server(instance: FakeInstance) -> None:
@pytest.fixture(scope="session")
def proxy_server(fake_instance: FakeInstance) -> Generator:
"""Run local proxy server capable of performing metadata exchange"""
thread = Thread(target=start_proxy_server, args=(fake_instance,), daemon=True)
thread = Thread(
target=asyncio.run,
args=(
start_proxy_server(
fake_instance,
),
),
daemon=True,
)
thread.start()
yield thread
thread.join()
2 changes: 1 addition & 1 deletion tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def get_preferred_ip(self, ip_type: Any) -> Tuple[str, Any]:
f.set_result("10.0.0.1")
return f

def create_ssl_context(self) -> None:
async def create_ssl_context(self) -> None:
return None

async def force_refresh(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError


def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
async def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
"""
Test to check whether the __init__ method of ConnectionInfo
can correctly initialize TLS context.
Expand Down Expand Up @@ -58,19 +58,19 @@ def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
fake_instance.ip_addrs,
datetime.now(timezone.utc) + timedelta(minutes=10),
)
context = conn_info.create_ssl_context()
context = await conn_info.create_ssl_context()
# verify TLS requirements
assert context.minimum_version == ssl.TLSVersion.TLSv1_3


def test_ConnectionInfo_caches_sslcontext() -> None:
async def test_ConnectionInfo_caches_sslcontext() -> None:
info = ConnectionInfo(["cert"], "cert", "key".encode(), {}, datetime.now())
# context should default to None
assert info.context is None
# cache a 'context'
info.context = "context"
# calling create_ssl_context should no-op with an existing 'context'
info.create_ssl_context()
await info.create_ssl_context()
assert info.context == "context"


Expand Down

0 comments on commit ba434e7

Please sign in to comment.