Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: OBS-435 - diode-sdk-python: tls_verify + grpc client call interceptor #75

Merged
merged 6 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 151 additions & 16 deletions diode-sdk-python/netboxlabs/diode/sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,70 @@
#!/usr/bin/env python
# Copyright 2024 NetBox Labs Inc
"""NetBox Labs, Diode - SDK - Client."""

import collections
import logging
import os
import platform
import uuid
from typing import Iterable, Optional
from typing import Dict, Iterable, Optional

import certifi
import grpc

from netboxlabs.diode.sdk.diode.v1 import ingester_pb2, ingester_pb2_grpc
from netboxlabs.diode.sdk.exceptions import DiodeClientError, DiodeConfigError

_DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY"
_DIODE_API_TLS_VERIFY_ENVVAR_NAME = "DIODE_API_TLS_VERIFY"
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
_DEFAULT_STREAM = "latest"
_LOGGER = logging.getLogger(__name__)


def _certs() -> bytes:
with open(certifi.where(), "rb") as f:
return f.read()


def _api_key(api_key: Optional[str] = None) -> str:
if api_key is None:
api_key = os.getenv(_DIODE_API_KEY_ENVVAR_NAME)
if api_key is None:
raise DiodeConfigError(
f"api_key param or {_DIODE_API_KEY_ENVVAR_NAME} environment variable required"
)
return api_key


def _tls_verify(tls_verify: Optional[bool]) -> bool:
if tls_verify is None:
tls_verify_env_var = os.getenv(_DIODE_API_TLS_VERIFY_ENVVAR_NAME, "false")
return tls_verify_env_var.lower() in ["true", "1", "yes"]
if not isinstance(tls_verify, bool):
raise DiodeConfigError("tls_verify must be a boolean")

return tls_verify


def parse_target(target: str) -> Dict[str, str]:
"""Parse target."""
if target.startswith(("http://", "https://")):
raise ValueError("target should not contain http:// or https://")

parts = [str(part) for part in target.split("/") if part != ""]

authority = ":".join([str(part) for part in parts[0].split(":") if part != ""])

if ":" not in authority:
authority += ":443"

path = ""
if len(parts) > 1:
path = "/" + "/".join(parts[1:])

return authority, path


class DiodeClient:
"""Diode Client."""

Expand All @@ -34,27 +81,48 @@ def __init__(
app_name: str,
app_version: str,
api_key: Optional[str] = None,
tls_verify: bool = None,
):
"""Initiate a new client."""
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
logging.basicConfig(level=log_level)

# TODO: validate target
self._target = target

self._target, self._path = parse_target(target)
self._app_name = app_name
self._app_version = app_version

if api_key is None:
api_key = os.getenv(_DIODE_API_KEY_ENVVAR_NAME)
if api_key is None:
raise DiodeConfigError("API key is required")
api_key = _api_key(api_key)
self._metadata = (
("diode-api-key", api_key),
("platform", platform.platform()),
("python-version", platform.python_version()),
)

self._tls_verify = _tls_verify(tls_verify)

if self._tls_verify:
self._channel = grpc.secure_channel(
self._target,
grpc.ssl_channel_credentials(
root_certificates=_certs(),
),
)
else:
self._channel = grpc.insecure_channel(
target=self._target,
)

channel = self._channel

if self._path:
rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path)

intercept_channel = grpc.intercept_channel(
self._channel, rpc_method_interceptor
)
channel = intercept_channel

self._auth_metadata = (("diode-api-key", api_key),)
# TODO: add support for secure channel (TLS verify flag and cert)
self._channel = grpc.insecure_channel(target)
self._stub = ingester_pb2_grpc.IngesterServiceStub(self._channel)
# TODO: obtain meta data about the environment; Python version, CPU arch, OS
self._stub = ingester_pb2_grpc.IngesterServiceStub(channel)

@property
def name(self) -> str:
Expand All @@ -71,6 +139,16 @@ def target(self) -> str:
"""Retrieve the target."""
return self._target

@property
def path(self) -> str:
"""Retrieve the path."""
return self._path

@property
def tls_verify(self) -> str:
"""Retrieve the tls_verify."""
return self._tls_verify

@property
def app_name(self) -> str:
"""Retrieve the app name."""
Expand Down Expand Up @@ -103,7 +181,7 @@ def ingest(
entities: Iterable[ingester_pb2.Entity],
stream: Optional[str] = _DEFAULT_STREAM,
) -> ingester_pb2.IngestResponse:
"""Push a message."""
"""Ingest entities."""
try:
request = ingester_pb2.IngestRequest(
stream=stream,
Expand All @@ -115,6 +193,63 @@ def ingest(
producer_app_version=self.app_version,
)

return self._stub.Ingest(request, metadata=self._auth_metadata)
return self._stub.Ingest(request, metadata=self._metadata)
except grpc.RpcError as err:
raise DiodeClientError(err) from err


class _ClientCallDetails(
collections.namedtuple(
"_ClientCallDetails",
(
"method",
"timeout",
"metadata",
"credentials",
"wait_for_ready",
"compression",
),
),
grpc.ClientCallDetails,
):
"""Client Call Details."""

pass


class DiodeMethodClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
):
"""Diode Method Client Interceptor."""

def __init__(self, subpath):
"""Initiate a new interceptor."""
self._subpath = subpath

def _intercept_call(self, continuation, client_call_details, request_or_iterator):
"""Intercept call."""
method = client_call_details.method
if client_call_details.method is not None:
method = f"{self._subpath}{client_call_details.method}"

client_call_details = _ClientCallDetails(
method,
client_call_details.timeout,
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression,
)

response = continuation(client_call_details, request_or_iterator)
return response

def intercept_unary_unary(self, continuation, client_call_details, request):
"""Intercept unary unary."""
return self._intercept_call(continuation, client_call_details, request)

def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
"""Intercept stream unary."""
return self._intercept_call(continuation, client_call_details, request_iterator)
1 change: 1 addition & 0 deletions diode-sdk-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [ # Optional
]

dependencies = [
"certifi==2024.2.2",
"grpcio==1.62.1",
"grpcio-status==1.62.1",
]
Expand Down
24 changes: 20 additions & 4 deletions diode-sdk-python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,40 @@

def test_init():
"""Ensure we can initiate a client configuration."""
config = DiodeClient(target="localhost:8081", app_name="my-producer", app_version="0.0.1", api_key="abcde")
config = DiodeClient(
target="localhost:8081",
app_name="my-producer",
app_version="0.0.1",
api_key="abcde",
)
assert config.target == "localhost:8081"
assert config.name == "diode-sdk-python"
assert config.version == "0.0.1"
assert config.app_name == "my-producer"
assert config.app_version == "0.0.1"
assert config.tls_verify is False
assert config.path == ""


def test_config_error():
"""Ensure we can raise a config error."""
with pytest.raises(DiodeConfigError) as err:
DiodeClient(target="localhost:8081", app_name="my-producer", app_version="0.0.1")
assert str(err.value) == "API key is required"
DiodeClient(
target="localhost:8081", app_name="my-producer", app_version="0.0.1"
)
assert (
str(err.value) == "api_key param or DIODE_API_KEY environment variable required"
)


def test_client_error():
"""Ensure we can raise a client error."""
with pytest.raises(DiodeClientError) as err:
client = DiodeClient(target="invalid:8081", app_name="my-producer", app_version="0.0.1", api_key="abcde")
client = DiodeClient(
target="invalid:8081",
app_name="my-producer",
app_version="0.0.1",
api_key="abcde",
)
client.ingest(entities=[])
assert err.value.status_code == grpc.StatusCode.UNAVAILABLE
Loading