Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions nilai-models/src/nilai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,52 @@ def __init__(self, metadata: ModelMetadata, prefix="/models"):
def setup_app(self):
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Register this model with etcd on startup; unregister on shutdown.

    Startup: builds a ModelServiceDiscovery client, validates the model
    metadata, registers the endpoint, and spawns a background task that
    keeps the registration lease alive.  Shutdown (or a startup failure):
    cancels the keepalive task, then unregisters the endpoint.

    NOTE(review): the pasted diff contained both the pre-change and
    post-change bodies concatenated; this is the merged (new) version.
    """
    discovery_service = None
    keep_alive_task = None

    try:
        # Initialize discovery service from service-wide settings.
        discovery_service = ModelServiceDiscovery(
            host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"]
        )

        # Fail fast if the endpoint was constructed without metadata.
        if not self.endpoint or not self.endpoint.metadata:
            raise ValueError("Invalid model metadata")

        # Register model and start the lease keepalive in the background.
        logger.info(f"Registering model: {self.endpoint.metadata.id}")
        lease = await discovery_service.register_model(
            self.endpoint, self.prefix
        )
        keep_alive_task = asyncio.create_task(
            discovery_service.keep_alive(lease),
            name=f"keepalive_{self.endpoint.metadata.id}",
        )

        logger.info(f"Model registered successfully: {self.endpoint}")
        yield

    except Exception as e:
        logger.error(f"Failed to initialize model service: {e}")
        raise
    finally:
        # Cancel the keepalive first so it cannot refresh the lease
        # while we are unregistering below.
        if keep_alive_task and not keep_alive_task.done():
            keep_alive_task.cancel()
            try:
                await keep_alive_task
            except asyncio.CancelledError:
                pass

        if discovery_service:
            try:
                await discovery_service.unregister_model(
                    self.endpoint.metadata.id
                )
                logger.info(f"Model unregistered: {self.endpoint.metadata.id}")
            except Exception as e:
                # Best-effort cleanup: shutdown must not raise here.
                logger.error(f"Error unregistering model: {e}")
# Create a FastAPI application instance for the model
self.app = FastAPI(lifespan=lifespan)
Expand Down
1 change: 1 addition & 0 deletions packages/nilai-common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ requires-python = ">=3.12"
dependencies = [
"etcd3gw>=2.4.2",
"pydantic>=2.10.1",
"tenacity>=9.0.0",
]

[build-system]
Expand Down
87 changes: 64 additions & 23 deletions packages/nilai-common/src/nilai_common/discovery.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import asyncio
import logging
from typing import Dict, Optional

from asyncio import CancelledError
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt


from etcd3gw import Lease
from etcd3gw.client import Etcd3Client
from nilai_common.api_model import ModelEndpoint, ModelMetadata

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelServiceDiscovery:
def __init__(self, host: str = "localhost", port: int = 2379, lease_ttl: int = 60):
Expand All @@ -15,8 +25,22 @@ def __init__(self, host: str = "localhost", port: int = 2379, lease_ttl: int = 6
:param port: etcd server port
:param lease_ttl: Lease time for endpoint registration (in seconds)
"""
self.client = Etcd3Client(host=host, port=port)
self.host = host
self.port = port
self.lease_ttl = lease_ttl
self.initialize()

self.is_healthy = True
self.last_refresh = None
self.max_retries = 3
self.base_delay = 1
self._shutdown = False

def initialize(self) -> None:
    """
    Initialize (or re-initialize) the etcd client.

    Creates a fresh Etcd3Client from the stored host/port.  Called from
    __init__ and again from keep_alive() after a failed lease refresh to
    re-establish the etcd connection.
    """
    # NOTE(review): appears to be a cheap constructor — connection errors
    # presumably surface on the first request, not here; confirm against
    # etcd3gw behavior.
    self.client = Etcd3Client(host=self.host, port=self.port)

async def register_model(
self, model_endpoint: ModelEndpoint, prefix: str = "/models"
Expand Down Expand Up @@ -73,7 +97,7 @@ async def discover_models(

discovered_models[model_endpoint.metadata.id] = model_endpoint
except Exception as e:
print(f"Error parsing model endpoint: {e}")
logger.error(f"Error parsing model endpoint: {e}")
return discovered_models

async def get_model(
Expand Down Expand Up @@ -101,19 +125,36 @@ async def unregister_model(self, model_id: str):
key = f"/models/{model_id}"
self.client.delete(key)

async def keep_alive(self, lease: Lease):
"""
Keep the model registration lease alive.

:param lease_id: Lease ID to keep alive
"""
while True:
try:
lease.refresh()
await asyncio.sleep(self.lease_ttl // 2)
except Exception as e:
print(f"Lease keepalive failed: {e}")
break
@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3)
)
async def _refresh_lease(self, lease):
    """Refresh the etcd lease, retrying with exponential backoff.

    tenacity retries up to 3 attempts (4-10s exponential wait) before
    surfacing the failure to the caller (keep_alive's except block).
    On success, records the refresh time and marks the service healthy.

    :param lease: the etcd lease object to refresh
    """
    # NOTE(review): lease.refresh() looks like a synchronous etcd3gw
    # call, so it blocks the event loop for the duration of the request
    # — consider asyncio.to_thread if refresh latency matters; confirm.
    lease.refresh()
    self.last_refresh = datetime.now()
    self.is_healthy = True

async def keep_alive(self, lease):
    """Keep the model registration lease alive with graceful shutdown.

    Loops until cancelled (or self._shutdown is set): refreshes the
    lease via _refresh_lease, then sleeps for half the lease TTL.  On
    refresh failure the etcd client is rebuilt and the loop retries
    after self.base_delay seconds.  On cancellation, marks shutdown and
    re-raises so the awaiting task observes the cancellation.

    :param lease: the etcd lease returned by register_model
    """
    try:
        while not self._shutdown:
            try:
                await self._refresh_lease(lease)
                # Refresh at twice the expiry rate so one missed
                # refresh does not lose the registration.
                await asyncio.sleep(self.lease_ttl // 2)
            except Exception as e:
                # CancelledError is a BaseException (3.8+), so it is
                # NOT swallowed here — it reaches the handler below.
                self.is_healthy = False
                logger.error(f"Lease keepalive failed: {e}")
                try:
                    # Rebuild the etcd client and repoint the existing
                    # lease at it — presumably so refresh() goes through
                    # the new connection (relies on etcd3gw Lease
                    # keeping its client in `.client`; verify).
                    self.initialize()
                    lease.client = self.client
                except Exception as init_error:
                    logger.error(f"Reinitialization failed: {init_error}")
                # Brief backoff before the next refresh attempt.
                await asyncio.sleep(self.base_delay)
    except CancelledError:
        logger.info("Lease keepalive task cancelled, shutting down...")
        self._shutdown = True
        raise  # propagate cancellation to the awaiting task
    finally:
        # Whatever the exit path, the lease is no longer being
        # refreshed, so report unhealthy.
        self.is_healthy = False


# Example usage
Expand Down Expand Up @@ -146,23 +187,23 @@ async def main():
discovered_models = await service_discovery.discover_models(
name="Image Classification", feature="image_classification"
)
print("FOUND: ", len(discovered_models))
logger.info(f"FOUND: {len(discovered_models)}")
for model in discovered_models.values():
print(f"Discovered Model: {model.metadata.id}")
print(f"URL: {model.url}")
print(f"Supported Features: {model.metadata.supported_features}")
logger.info(f"Discovered Model: {model.metadata.id}")
logger.info(f"URL: {model.url}")
logger.info(f"Supported Features: {model.metadata.supported_features}")

# Optional: Keep the service running
await asyncio.sleep(10) # Keep running for an hour
# Discover models (with optional filtering)
discovered_models = await service_discovery.discover_models(
name="Image Classification", feature="image_classification"
)
print("FOUND: ", len(discovered_models))
logger.info(f"FOUND: {len(discovered_models)}")
for model in discovered_models.values():
print(f"Discovered Model: {model.metadata.id}")
print(f"URL: {model.url}")
print(f"Supported Features: {model.metadata.supported_features}")
logger.info(f"Discovered Model: {model.metadata.id}")
logger.info(f"URL: {model.url}")
logger.info(f"Supported Features: {model.metadata.supported_features}")

# Cleanup
await service_discovery.unregister_model(model_endpoint.metadata.id)
Expand Down
39 changes: 39 additions & 0 deletions packages/nilai-common/src/nilai_common/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
import sys
from typing import Optional
from pathlib import Path


def setup_logger(
    name: str,
    level: int = logging.INFO,
    log_file: Optional[Path] = None,
) -> logging.Logger:
    """Configure common logger for Nilai services.

    Idempotent: ``logging.getLogger(name)`` returns the same logger on
    every call, so the original implementation stacked a fresh handler
    per call, duplicating every log line.  This version removes handlers
    attached by previous calls before adding the new ones.

    :param name: logger name (typically the service or module name)
    :param level: logging level (default ``logging.INFO``)
    :param log_file: optional log file path; parent directories are
        created if missing
    :return: the configured logger instance
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Drop handlers added by earlier setup_logger calls so repeated
    # configuration does not duplicate output (fixes handler stacking).
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Console handler (stdout)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler if path provided
    if log_file:
        log_file.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(str(log_file))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


# Default logger instance, created at import time for modules that do not
# need a per-service named logger; logs INFO+ to stdout (no file handler).
default_logger = setup_logger("nilai")
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading