Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable TLS auth on py client #64

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ This library provides a high level interface for interacting with a model regist
```py
from model_registry import ModelRegistry

registry = ModelRegistry(server_address="server-address", port=9090, author="author")
registry = ModelRegistry("server-address", author="Ada Lovelace") # Defaults to a secure connection via port 443

# registry = ModelRegistry("server-address", 1234, author="Ada Lovelace", is_secure=False) # To use MR without TLS

model = registry.register_model(
"my-model", # model name
Expand Down
54 changes: 42 additions & 12 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import os
from pathlib import Path
from typing import get_args
from warnings import warn

Expand All @@ -17,27 +19,55 @@ class ModelRegistry:
def __init__(
self,
server_address: str,
port: int,
port: int = 443,
*,
author: str,
client_key: str | None = None,
server_cert: str | None = None,
custom_ca: str | None = None,
is_secure: bool = True,
user_token: bytes | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a requirement right now, but we should open another issue to add support for client certificate as well for mutual authentication.

custom_ca: bytes | None = None,
):
"""Constructor.

Args:
server_address: Server address.
port: Server port.
port: Server port. Defaults to 443.

Keyword Args:
author: Name of the author.
client_key: The PEM-encoded private key as a byte string.
server_cert: The PEM-encoded certificate as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT.
"""
# TODO: get args from env
# TODO: get remaining args from env
self._author = author
self._api = ModelRegistryAPIClient(
server_address, port, client_key, server_cert, custom_ca
)

if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
else:
warn("User access token is missing", stacklevel=2)

if is_secure:
root_ca = None
if not custom_ca:
if ca_path := os.getenv("CERT"):
rareddy marked this conversation as resolved.
Show resolved Hide resolved
root_ca = Path(ca_path).read_bytes()
# client might have a default CA setup
else:
root_ca = custom_ca

self._api = ModelRegistryAPIClient.secure_connection(
server_address, port, user_token, root_ca
)
elif custom_ca:
msg = "Custom CA provided without secure connection"
raise StoreException(msg)
else:
self._api = ModelRegistryAPIClient.insecure_connection(
server_address, port, user_token
)

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
118 changes: 75 additions & 43 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,79 @@

from __future__ import annotations

from ml_metadata.proto import MetadataStoreClientConfig
from dataclasses import dataclass

import grpc

from .exceptions import StoreException
from .store import MLMDStore, ProtoType
from .types import ListOptions, ModelArtifact, ModelVersion, RegisteredModel
from .types.base import ProtoBase
from .types.options import MLMDListOptions
from .utils import header_adder_interceptor


@dataclass
class ModelRegistryAPIClient:
"""Model registry API."""

def __init__(
self,
store: MLMDStore

@classmethod
def secure_connection(
cls,
server_address: str,
port: int = 443,
user_token: bytes | None = None,
custom_ca: bytes | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.

Args:
server_address: Server address.
port: Server port. Defaults to 443.
user_token: The PEM-encoded user token as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to GRPC_DEFAULT_SSL_ROOTS_FILE_PATH, then system default.
"""
if not user_token:
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
msg = "user token must be provided for secure connection"
raise StoreException(msg)

chan = grpc.secure_channel(
f"{server_address}:{port}",
grpc.composite_channel_credentials(
# custom_ca = None will get the default root certificates
grpc.ssl_channel_credentials(custom_ca),
grpc.access_token_call_credentials(user_token),
),
)

return cls(MLMDStore.from_channel(chan))

@classmethod
def insecure_connection(
cls,
server_address: str,
port: int,
client_key: str | None = None,
server_cert: str | None = None,
custom_ca: str | None = None,
):
user_token: bytes | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.

Args:
server_address: Server address.
port: Server port.
client_key: The PEM-encoded private key as a byte string.
server_cert: The PEM-encoded certificate as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string.
user_token: The PEM-encoded user token as a byte string.
"""
config = MetadataStoreClientConfig()
config.host = server_address
config.port = port
if client_key is not None:
config.ssl_config.client_key = client_key
if server_cert is not None:
config.ssl_config.server_cert = server_cert
if custom_ca is not None:
config.ssl_config.custom_ca = custom_ca
self._store = MLMDStore(config)
if user_token:
chan = grpc.intercept_channel(
grpc.insecure_channel(f"{server_address}:{port}"),
# header key has to be lowercase
header_adder_interceptor("authorization", f"Bearer {user_token}"),
)
else:
chan = grpc.insecure_channel(f"{server_address}:{port}")

return cls(MLMDStore.from_channel(chan))

def _map(self, py_obj: ProtoBase) -> ProtoType:
"""Map a Python object to a proto object.
Expand All @@ -53,7 +87,7 @@ def _map(self, py_obj: ProtoBase) -> ProtoType:
Returns:
Proto object.
"""
type_id = self._store.get_type_id(
type_id = self.store.get_type_id(
py_obj.get_proto_type(), py_obj.get_proto_type_name()
)
return py_obj.map(type_id)
Expand All @@ -70,9 +104,9 @@ def upsert_registered_model(self, registered_model: RegisteredModel) -> str:
Returns:
ID of the registered model.
"""
id = self._store.put_context(self._map(registered_model))
id = self.store.put_context(self._map(registered_model))
new_py_rm = RegisteredModel.unmap(
self._store.get_context(RegisteredModel.get_proto_type_name(), id)
self.store.get_context(RegisteredModel.get_proto_type_name(), id)
)
id = str(id)
registered_model.id = id
Expand All @@ -91,7 +125,7 @@ def get_registered_model_by_id(self, id: str) -> RegisteredModel | None:
Returns:
Registered model.
"""
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(), id=int(id)
)
if proto_rm is not None:
Expand All @@ -117,7 +151,7 @@ def get_registered_model_by_params(
if name is None and external_id is None:
msg = "Either name or external_id must be provided"
raise StoreException(msg)
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(),
name=name,
external_id=external_id,
Expand All @@ -139,7 +173,7 @@ def get_registered_models(
Registered models.
"""
mlmd_options = options.as_mlmd_list_options() if options else MLMDListOptions()
proto_rms = self._store.get_contexts(
proto_rms = self.store.get_contexts(
RegisteredModel.get_proto_type_name(), mlmd_options
)
return [RegisteredModel.unmap(proto_rm) for proto_rm in proto_rms]
Expand All @@ -161,10 +195,10 @@ def upsert_model_version(
"""
# this is not ideal but we need this info for the prefix
model_version._registered_model_id = registered_model_id
id = self._store.put_context(self._map(model_version))
self._store.put_context_parent(int(registered_model_id), id)
id = self.store.put_context(self._map(model_version))
self.store.put_context_parent(int(registered_model_id), id)
new_py_mv = ModelVersion.unmap(
self._store.get_context(ModelVersion.get_proto_type_name(), id)
self.store.get_context(ModelVersion.get_proto_type_name(), id)
)
id = str(id)
model_version.id = id
Expand All @@ -183,7 +217,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:
Returns:
Model version.
"""
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), id=int(model_version_id)
)
if proto_mv is not None:
Expand All @@ -207,7 +241,7 @@ def get_model_versions(
mlmd_options.filter_query = f"parent_contexts_a.id = {registered_model_id}"
return [
ModelVersion.unmap(proto_mv)
for proto_mv in self._store.get_contexts(
for proto_mv in self.store.get_contexts(
ModelVersion.get_proto_type_name(), mlmd_options
)
]
Expand All @@ -234,7 +268,7 @@ def get_model_version_by_params(
StoreException: If neither external ID nor registered model ID and version is provided.
"""
if external_id is not None:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), external_id=external_id
)
elif registered_model_id is None or version is None:
Expand All @@ -243,7 +277,7 @@ def get_model_version_by_params(
)
raise StoreException(msg)
else:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(),
name=f"{registered_model_id}:{version}",
)
Expand Down Expand Up @@ -271,17 +305,17 @@ def upsert_model_artifact(
StoreException: If the model version already has a model artifact.
"""
mv_id = int(model_version_id)
if self._store.get_attributed_artifact(
if self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), mv_id
):
msg = f"Model version with ID {mv_id} already has a model artifact"
raise StoreException(msg)

model_artifact._model_version_id = model_version_id
id = self._store.put_artifact(self._map(model_artifact))
self._store.put_attribution(mv_id, id)
id = self.store.put_artifact(self._map(model_artifact))
self.store.put_attribution(mv_id, id)
new_py_ma = ModelArtifact.unmap(
self._store.get_artifact(ModelArtifact.get_proto_type_name(), id)
self.store.get_artifact(ModelArtifact.get_proto_type_name(), id)
)
id = str(id)
model_artifact.id = id
Expand All @@ -300,9 +334,7 @@ def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None:
Returns:
Model artifact.
"""
proto_ma = self._store.get_artifact(
ModelArtifact.get_proto_type_name(), int(id)
)
proto_ma = self.store.get_artifact(ModelArtifact.get_proto_type_name(), int(id))
if proto_ma is not None:
return ModelArtifact.unmap(proto_ma)

Expand All @@ -324,14 +356,14 @@ def get_model_artifact_by_params(
StoreException: If neither external ID nor model version ID is provided.
"""
if external_id:
proto_ma = self._store.get_artifact(
proto_ma = self.store.get_artifact(
ModelArtifact.get_proto_type_name(), external_id=external_id
)
elif not model_version_id:
msg = "Either model_version_id or external_id must be provided"
raise StoreException(msg)
else:
proto_ma = self._store.get_attributed_artifact(
proto_ma = self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), int(model_version_id)
)
if proto_ma is not None:
Expand All @@ -357,7 +389,7 @@ def get_model_artifacts(
if model_version_id is not None:
mlmd_options.filter_query = f"contexts_a.id = {model_version_id}"

proto_mas = self._store.get_artifacts(
proto_mas = self.store.get_artifacts(
ModelArtifact.get_proto_type_name(), mlmd_options
)
return [ModelArtifact.unmap(proto_ma) for proto_ma in proto_mas]
Loading
Loading