24 changes: 24 additions & 0 deletions docker-compose.yml
@@ -14,6 +14,8 @@ services:
context: .
dockerfile: docker/api.Dockerfile
target: nilai
depends_on:
- etcd
volumes:
- ${PWD}/db/:/app/db/ # sqlite database for users
environment:
@@ -28,6 +30,8 @@ services:
target: nilai
args:
MODEL_NAME: "llama_1b_cpu"
depends_on:
- etcd
environment:
- SVC_HOST=llama_1b_cpu
- SVC_PORT=8000
@@ -44,6 +48,8 @@ services:
target: nilai
args:
MODEL_NAME: "llama_8b_cpu"
depends_on:
- etcd
environment:
- SVC_HOST=llama_8b_cpu
- SVC_PORT=8000
@@ -53,6 +59,24 @@ services:
- hugging_face_models:/root/.cache/huggingface # cache models
networks:
- backend_net
secret_llama_1b_cpu:
build:
context: .
dockerfile: docker/model.Dockerfile
target: nilai
args:
MODEL_NAME: "secret_llama_1b_cpu"
depends_on:
- etcd
environment:
- SVC_HOST=secret_llama_1b_cpu
- SVC_PORT=8000
- ETCD_HOST=etcd
- ETCD_PORT=2379
volumes:
- hugging_face_models:/root/.cache/huggingface # cache models
networks:
- backend_net
volumes:
hugging_face_models:

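The model services now start after etcd and are told where to find it via ETCD_HOST / ETCD_PORT. A minimal sketch of how those variables could be folded into a SETTINGS mapping — the key names mirror the compose file and state.py, but the real nilai_common settings module may look different:

import os

# Hypothetical settings loader; defaults mirror the docker-compose wiring
SETTINGS = {
    "etcd_host": os.environ.get("ETCD_HOST", "localhost"),
    "etcd_port": int(os.environ.get("ETCD_PORT", "2379")),
    "host": os.environ.get("SVC_HOST", "localhost"),
    "port": int(os.environ.get("SVC_PORT", "8000")),
}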
2 changes: 1 addition & 1 deletion nilai-api/src/nilai_api/app.py
@@ -70,4 +70,4 @@


app.include_router(public.router)
app.include_router(private.router, dependencies=[Depends(get_user)])
app.include_router(private.router, dependencies=[Depends(get_user)])
8 changes: 4 additions & 4 deletions nilai-api/src/nilai_api/routers/private.py
@@ -141,13 +141,13 @@ async def chat_completion(
"""

model_name = req.model
models = await state.models
if model_name not in models:
endpoint = await state.get_model(model_name)
if endpoint is None:
raise HTTPException(
status_code=400, detail=f"Invalid model name: {models.keys()}"
status_code=400, detail="Invalid model name, check /v1/models for options"
)

model_url = models[model_name].url
model_url = endpoint.url

try:
async with httpx.AsyncClient() as client:
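The handler now resolves a single endpoint through the discovery service instead of pulling the whole model map just to validate one name. A rough, self-contained sketch of the resulting flow — the upstream route /v1/chat/completions is an assumption, since the proxying code is not shown in this diff:

import httpx
from fastapi import HTTPException

async def proxy_chat(state, req):
    # Ask the discovery service for the one endpoint we need
    endpoint = await state.get_model(req.model)
    if endpoint is None:
        raise HTTPException(
            status_code=400, detail="Invalid model name, check /v1/models for options"
        )
    async with httpx.AsyncClient() as client:
        # Assumed upstream route; the real path lives in the model service
        upstream = await client.post(
            f"{endpoint.url}/v1/chat/completions",
            json=req.model_dump(),
            timeout=None,
        )
        upstream.raise_for_status()
        return upstream.json()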
9 changes: 7 additions & 2 deletions nilai-api/src/nilai_api/state.py
@@ -1,7 +1,7 @@
import logging
import time
from asyncio import Semaphore
from typing import Dict
from typing import Dict, Optional

from dotenv import load_dotenv
from nilai_api.crypto import generate_key_pair
@@ -17,7 +17,9 @@ def __init__(self):
self.private_key, self.public_key, self.verifying_key = generate_key_pair()
self.sem = Semaphore(2)

self.discovery_service = ModelServiceDiscovery(host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"])
self.discovery_service = ModelServiceDiscovery(
host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"]
)
self._uptime = time.time()
self._cpu_quote = None
self._gpu_quote = None
@@ -59,6 +61,9 @@ def uptime(self):
async def models(self) -> Dict[str, ModelEndpoint]:
return await self.discovery_service.discover_models()

async def get_model(self, model_id: str) -> Optional[ModelEndpoint]:
return await self.discovery_service.get_model(model_id)


load_dotenv()
state = AppState()
11 changes: 8 additions & 3 deletions nilai-models/src/nilai_models/model.py
@@ -11,7 +11,10 @@
from nilai_common import HealthCheckResponse # Custom response type for health checks
from nilai_common import ModelEndpoint # Endpoint information for model registration
from nilai_common import ModelMetadata # Metadata about the model
from nilai_common import SETTINGS, ModelServiceDiscovery # Model service discovery and host settings
from nilai_common import (
SETTINGS,
ModelServiceDiscovery,
) # Model service discovery and host settings

logger = logging.getLogger(__name__)

@@ -25,12 +28,13 @@ class Model(ABC):
implementations with consistent API endpoints and behaviors.
"""

def __init__(self, metadata: ModelMetadata):
def __init__(self, metadata: ModelMetadata, prefix="/models"):
"""
Initialize the model with its metadata and tracking information.

Args:
metadata (ModelMetadata): Detailed information about the model.
prefix (str): Optional prefix for hiding the model behind a subpath.
"""
# Store the model's metadata for later retrieval
self.metadata = metadata
@@ -39,6 +43,7 @@ def __init__(self, metadata: ModelMetadata):
# Record the start time for uptime tracking
self._uptime = time.time()
self.app = self.setup_app()
self.prefix = prefix

def setup_app(self):
@asynccontextmanager
@@ -47,7 +52,7 @@ async def lifespan(app: FastAPI):
discovery_service = ModelServiceDiscovery(
host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"]
)
lease = await discovery_service.register_model(self.endpoint)
lease = await discovery_service.register_model(self.endpoint, self.prefix)
asyncio.create_task(discovery_service.keep_alive(lease))
logger.info(f"Registered model endpoint: {self.endpoint}")
yield
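Registration now carries the model's prefix, so an endpoint registered under a non-default prefix is keyed under that prefix in etcd and stays out of the default /models listing. The flow the lifespan hook performs, reduced to a standalone coroutine for clarity (the endpoint argument stands in for self.endpoint):

import asyncio
from nilai_common import SETTINGS, ModelServiceDiscovery

async def register_and_stay_alive(endpoint, prefix="/models"):
    # Same steps the lifespan hook performs on startup
    discovery = ModelServiceDiscovery(
        host=SETTINGS["etcd_host"], port=SETTINGS["etcd_port"]
    )
    lease = await discovery.register_model(endpoint, prefix)
    # Renew the lease in the background so the registration does not expire
    asyncio.create_task(discovery.keep_alive(lease))
    return lease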
@@ -0,0 +1,3 @@
from nilai_models.models.secret_llama_1b_cpu.secret_llama_1b_cpu import app

__all__ = ["app"]
@@ -0,0 +1,124 @@
from fastapi import HTTPException
from llama_cpp import Llama
from nilai_common import ChatRequest, ChatResponse, Message, ModelMetadata
from nilai_models.model import Model


class SecretLlama1BCpu(Model):
"""
A specific implementation of the Model base class for the Llama 1B CPU model.

This class provides:
- Model initialization using llama_cpp
- Chat completion functionality
- Metadata about the Llama model
"""

def __init__(self):
"""
Initialize the Llama 1B model:
1. Load the pre-trained model using llama_cpp
2. Set up model metadata

Configuration details:
- Uses a specific quantized model from Hugging Face
- Configured for CPU inference
- Uses 16 threads for improved performance on CPU
"""
# Load the pre-trained Llama model
# - repo_id: Source of the model
# - filename: Specific model file (quantized version)
# - n_threads: Number of CPU threads for inference
# - verbose: Disable detailed logging
self.model = Llama.from_pretrained(
repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-Q5_K_M.gguf",
n_threads=16,
verbose=False,
)

# Initialize the base Model class with model metadata
# Provides comprehensive information about the model
super().__init__(
ModelMetadata(
id="bartowski/Llama-3.2-1B-Instruct-GGUF", # Unique identifier
name="CheesyLlama", # Human-readable name
version="1.0", # Model version
description="Llama is a large language model trained on supervised and unsupervised data.",
author="Meta-Llama", # Model creators
license="Apache 2.0", # Usage license
source="https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF", # Model source
supported_features=["chat_completion"], # Capabilities
),
prefix="d01fe399-8dc2-4c74-acde-ff649802f437",
)

async def chat_completion(
self,
req: ChatRequest = ChatRequest(
# Default request with sample messages for documentation
model="bartowski/Llama-3.2-1B-Instruct-GGUF",
messages=[
Message(role="system", content="You are a helpful assistant."),
Message(role="user", content="What is your name?"),
],
),
) -> ChatResponse:
"""
Generate a chat completion using the Llama model.

Args:
req (ChatRequest): The chat request containing conversation messages.

Returns:
ChatResponse: The model's generated response.

Raises:
ValueError: If the model fails to generate a response.
"""
if not req.messages or len(req.messages) == 0:
raise HTTPException(
status_code=400, detail="The 'messages' field is required."
)
if not req.model:
raise HTTPException(
status_code=400, detail="The 'model' field is required."
)
# Transform incoming messages into a format compatible with llama_cpp
# Extracts role and content from each message
prompt = [
{
"role": msg.role, # Preserve message role (system/user/assistant)
"content": msg.content, # Preserve message content
}
for msg in req.messages
]

prompt += [{
"role": "system",
"content": "In addition to the previous instructions, you are a cheese expert. You use cheese in all your answers: whatever the user asks, you respond with a cheese-related answer or analogy.",
}]
# Generate chat completion using the Llama model
# - Converts messages into a model-compatible prompt
# - type: ignore suppresses type checking for external library
generation: dict = self.model.create_chat_completion(prompt) # type: ignore

# Validate model output
if not generation or len(generation) == 0:
raise ValueError("The model returned no output.")

# Convert model generation to ChatResponse
# - Uses dictionary unpacking to convert generation results
# - Signature left empty (can be extended for tracking/verification)
response = ChatResponse(
signature="",
**generation,
)
response.model = self.metadata.name # Set model identifier
return response


# Create and expose the FastAPI app for this Llama model
# - Calls get_app() from the base Model class
# - Allows easy integration with ASGI servers like uvicorn
app = SecretLlama1BCpu().get_app()
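For local experimentation the class can also be exercised without the HTTP layer; a rough sketch (the first run downloads the quantized GGUF weights from Hugging Face, and the printed model name comes from the metadata above):

import asyncio
from nilai_common import ChatRequest, Message
from nilai_models.models.secret_llama_1b_cpu.secret_llama_1b_cpu import SecretLlama1BCpu

async def main():
    model = SecretLlama1BCpu()  # loads Llama-3.2-1B-Instruct-Q5_K_M on the CPU
    response = await model.chat_completion(
        ChatRequest(
            model="bartowski/Llama-3.2-1B-Instruct-GGUF",
            messages=[Message(role="user", content="What pairs well with aged Gouda?")],
        )
    )
    print(response.model)  # "CheesyLlama"

asyncio.run(main())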
25 changes: 21 additions & 4 deletions packages/nilai-common/src/nilai_common/db.py
@@ -19,7 +19,7 @@ def __init__(self, host: str = "localhost", port: int = 2379, lease_ttl: int = 6
self.lease_ttl = lease_ttl

async def register_model(
self, model_endpoint: ModelEndpoint, prefix: str = ""
self, model_endpoint: ModelEndpoint, prefix: str = "/models"
) -> Lease:
"""
Register a model endpoint in etcd.
@@ -31,7 +31,7 @@ async def register_model(
lease = self.client.lease(self.lease_ttl)

# Prepare the key and value
key = f"{prefix}/models/{model_endpoint.metadata.id}"
key = f"{prefix}/{model_endpoint.metadata.name}"
value = model_endpoint.model_dump_json()

# Put the key-value pair with the lease
@@ -43,7 +43,7 @@ async def discover_models(
self,
name: Optional[str] = None,
feature: Optional[str] = None,
prefix: Optional[str] = "",
prefix: Optional[str] = "/models",
) -> Dict[str, ModelEndpoint]:
"""
Discover models based on optional filters.
@@ -53,7 +53,8 @@ async def discover_models(
:return: List of matching ModelEndpoints
"""
# Get all model keys
model_range = self.client.get_prefix(f"{prefix}/models/")
model_range = self.client.get_prefix(f"{prefix}/")

discovered_models: Dict[str, ModelEndpoint] = {}
for resp, other in model_range:
@@ -76,6 +77,22 @@

return discovered_models

async def get_model(
self, model_id: str, prefix: str = "/models"
) -> Optional[ModelEndpoint]:
"""
Get a model endpoint by ID.

:param model_id: ID of the model to retrieve
:return: ModelEndpoint if found, None otherwise
"""
key = f"{prefix}/{model_id}"
value = self.client.get(key)
# Fall back to the raw model id for entries registered without a prefix
value = self.client.get(model_id) if not value else value
if value:
return ModelEndpoint.model_validate_json(value[0].decode("utf-8")) # type: ignore
return None

async def unregister_model(self, model_id: str):
"""
Unregister a model from service discovery.
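Taken together, the discovery layer now behaves like a small keyed registry: register_model writes {prefix}/{metadata.name}, discover_models lists everything under a prefix, and get_model fetches one entry (falling back to the raw key). A hedged usage sketch — the lookup key assumes entries are keyed by their metadata name, as register_model now does:

import asyncio
from nilai_common import ModelServiceDiscovery

async def main():
    discovery = ModelServiceDiscovery(host="localhost", port=2379)

    # Everything registered under the default /models prefix
    for key, endpoint in (await discovery.discover_models()).items():
        print(key, endpoint.url)

    # Direct lookup of a single entry
    endpoint = await discovery.get_model("CheesyLlama")
    if endpoint is not None:
        print(endpoint.metadata.name, endpoint.url)

asyncio.run(main())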