diff --git a/nilai-models/src/nilai_models/model.py b/nilai-models/src/nilai_models/model.py index c393e8a0..2e36b170 100644 --- a/nilai-models/src/nilai_models/model.py +++ b/nilai-models/src/nilai_models/model.py @@ -50,15 +50,52 @@ def __init__(self, metadata: ModelMetadata, prefix="/models"): def setup_app(self): @asynccontextmanager async def lifespan(app: FastAPI): - # Load the model on the API - discovery_service = ModelServiceDiscovery( - host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"] - ) - lease = await discovery_service.register_model(self.endpoint, self.prefix) - asyncio.create_task(discovery_service.keep_alive(lease)) - logger.info(f"Registered model endpoint: {self.endpoint}") - yield - await discovery_service.unregister_model(self.endpoint.metadata.id) + discovery_service = None + keep_alive_task = None + + try: + # Initialize discovery service + discovery_service = ModelServiceDiscovery( + host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"] + ) + + # Validate model metadata + if not self.endpoint or not self.endpoint.metadata: + raise ValueError("Invalid model metadata") + + # Register model and start keepalive + logger.info(f"Registering model: {self.endpoint.metadata.id}") + lease = await discovery_service.register_model( + self.endpoint, self.prefix + ) + keep_alive_task = asyncio.create_task( + discovery_service.keep_alive(lease), + name=f"keepalive_{self.endpoint.metadata.id}", + ) + + logger.info(f"Model registered successfully: {self.endpoint}") + yield + + except Exception as e: + logger.error(f"Failed to initialize model service: {e}") + raise + finally: + # Cleanup + if keep_alive_task and not keep_alive_task.done(): + keep_alive_task.cancel() + try: + await keep_alive_task + except asyncio.CancelledError: + pass + + if discovery_service: + try: + await discovery_service.unregister_model( + self.endpoint.metadata.id + ) + logger.info(f"Model unregistered: {self.endpoint.metadata.id}") + except Exception as e: + logger.error(f"Error unregistering model: {e}") # Create a FastAPI application instance for the model self.app = FastAPI(lifespan=lifespan) diff --git a/packages/nilai-common/pyproject.toml b/packages/nilai-common/pyproject.toml index 4ff7e6c7..19a554a4 100644 --- a/packages/nilai-common/pyproject.toml +++ b/packages/nilai-common/pyproject.toml @@ -10,6 +10,7 @@ requires-python = ">=3.12" dependencies = [ "etcd3gw>=2.4.2", "pydantic>=2.10.1", + "tenacity>=9.0.0", ] [build-system] diff --git a/packages/nilai-common/src/nilai_common/discovery.py b/packages/nilai-common/src/nilai_common/discovery.py index 8f58452b..6c670540 100644 --- a/packages/nilai-common/src/nilai_common/discovery.py +++ b/packages/nilai-common/src/nilai_common/discovery.py @@ -1,10 +1,20 @@ import asyncio +import logging from typing import Dict, Optional +from asyncio import CancelledError +from datetime import datetime +from tenacity import retry, wait_exponential, stop_after_attempt + + from etcd3gw import Lease from etcd3gw.client import Etcd3Client from nilai_common.api_model import ModelEndpoint, ModelMetadata +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + class ModelServiceDiscovery: def __init__(self, host: str = "localhost", port: int = 2379, lease_ttl: int = 60): @@ -15,8 +25,22 @@ def __init__(self, host: str = "localhost", port: int = 2379, lease_ttl: int = 6 :param port: etcd server port :param lease_ttl: Lease time for endpoint registration (in seconds) """ - self.client = Etcd3Client(host=host, port=port) + self.host = host + self.port = port self.lease_ttl = lease_ttl + self.initialize() + + self.is_healthy = True + self.last_refresh = None + self.max_retries = 3 + self.base_delay = 1 + self._shutdown = False + + def initialize(self): + """ + Initialize the etcd client. + """ + self.client = Etcd3Client(host=self.host, port=self.port) async def register_model( self, model_endpoint: ModelEndpoint, prefix: str = "/models" @@ -73,7 +97,7 @@ async def discover_models( discovered_models[model_endpoint.metadata.id] = model_endpoint except Exception as e: - print(f"Error parsing model endpoint: {e}") + logger.error(f"Error parsing model endpoint: {e}") return discovered_models async def get_model( @@ -101,19 +125,36 @@ async def unregister_model(self, model_id: str): key = f"/models/{model_id}" self.client.delete(key) - async def keep_alive(self, lease: Lease): - """ - Keep the model registration lease alive. - - :param lease_id: Lease ID to keep alive - """ - while True: - try: - lease.refresh() - await asyncio.sleep(self.lease_ttl // 2) - except Exception as e: - print(f"Lease keepalive failed: {e}") - break + @retry( + wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3) + ) + async def _refresh_lease(self, lease): + lease.refresh() + self.last_refresh = datetime.now() + self.is_healthy = True + + async def keep_alive(self, lease): + """Keep the model registration lease alive with graceful shutdown.""" + try: + while not self._shutdown: + try: + await self._refresh_lease(lease) + await asyncio.sleep(self.lease_ttl // 2) + except Exception as e: + self.is_healthy = False + logger.error(f"Lease keepalive failed: {e}") + try: + self.initialize() + lease.client = self.client + except Exception as init_error: + logger.error(f"Reinitialization failed: {init_error}") + await asyncio.sleep(self.base_delay) + except CancelledError: + logger.info("Lease keepalive task cancelled, shutting down...") + self._shutdown = True + raise + finally: + self.is_healthy = False # Example usage @@ -146,11 +187,11 @@ async def main(): discovered_models = await service_discovery.discover_models( name="Image Classification", feature="image_classification" ) - print("FOUND: ", len(discovered_models)) + logger.info(f"FOUND: {len(discovered_models)}") for model in discovered_models.values(): - print(f"Discovered Model: {model.metadata.id}") - print(f"URL: {model.url}") - print(f"Supported Features: {model.metadata.supported_features}") + logger.info(f"Discovered Model: {model.metadata.id}") + logger.info(f"URL: {model.url}") + logger.info(f"Supported Features: {model.metadata.supported_features}") # Optional: Keep the service running await asyncio.sleep(10) # Keep running for an hour @@ -158,11 +199,11 @@ async def main(): discovered_models = await service_discovery.discover_models( name="Image Classification", feature="image_classification" ) - print("FOUND: ", len(discovered_models)) + logger.info(f"FOUND: {len(discovered_models)}") for model in discovered_models.values(): - print(f"Discovered Model: {model.metadata.id}") - print(f"URL: {model.url}") - print(f"Supported Features: {model.metadata.supported_features}") + logger.info(f"Discovered Model: {model.metadata.id}") + logger.info(f"URL: {model.url}") + logger.info(f"Supported Features: {model.metadata.supported_features}") # Cleanup await service_discovery.unregister_model(model_endpoint.metadata.id) diff --git a/packages/nilai-common/src/nilai_common/logger.py b/packages/nilai-common/src/nilai_common/logger.py new file mode 100644 index 00000000..a07fc64d --- /dev/null +++ b/packages/nilai-common/src/nilai_common/logger.py @@ -0,0 +1,39 @@ +import logging +import sys +from typing import Optional +from pathlib import Path + + +def setup_logger( + name: str, + level: int = logging.INFO, + log_file: Optional[Path] = None, +) -> logging.Logger: + """Configure common logger for Nilai services.""" + + # Create logger with service name + logger = logging.getLogger(name) + logger.setLevel(level) + + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # File handler if path provided + if log_file: + log_file.parent.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(str(log_file)) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +# Default logger instance +default_logger = setup_logger("nilai") diff --git a/uv.lock b/uv.lock index 7e68ba3b..7ec2c16c 100644 --- a/uv.lock +++ b/uv.lock @@ -659,12 +659,14 @@ source = { editable = "packages/nilai-common" } dependencies = [ { name = "etcd3gw" }, { name = "pydantic" }, + { name = "tenacity" }, ] [package.metadata] requires-dist = [ { name = "etcd3gw", specifier = ">=2.4.2" }, { name = "pydantic", specifier = ">=2.10.1" }, + { name = "tenacity", specifier = ">=9.0.0" }, ] [[package]] @@ -1269,6 +1271,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 }, ] +[[package]] +name = "tenacity" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/94/91fccdb4b8110642462e653d5dcb27e7b674742ad68efd146367da7bdb10/tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b", size = 47421 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/cb/b86984bed139586d01532a587464b5805f12e397594f19f931c4c2fbfa61/tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539", size = 28169 }, +] + [[package]] name = "tokenizers" version = "0.20.3"