Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: Python Tests

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]  # Adjust branches as needed

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      # setup-uv installs uv; enable-cache lets the action manage uv's cache,
      # keyed on pyproject.toml changes.
      - uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true
          cache-dependency-glob: "**/pyproject.toml"

      # NOTE(review): setup-uv's enable-cache above already caches the uv
      # cache directory, and nothing in this workflow sets UV_CACHE_DIR, so
      # env.UV_CACHE_DIR may expand to empty here — confirm whether this
      # step is redundant and can be removed.
      - name: Cache dependencies
        uses: actions/cache@v3
        with:
          path: ${{ env.UV_CACHE_DIR }}
          key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-uv-

      - name: Install dependencies
        run: |
          uv sync

      - name: Run Ruff format check
        run: uv run ruff format --check

      - name: Run Ruff linting
        run: uv run ruff check

      - name: Run tests
        run: uv run pytest -v
8 changes: 4 additions & 4 deletions nilai-api/src/nilai_api/routers/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def get_usage(user: dict = Depends(get_user)) -> Usage:
"""
Retrieve the current token usage for the authenticated user.

- **user**: Authenticated user information (through X-API-Key header)
- **user**: Authenticated user information (through HTTP Bearer header)
- **Returns**: Usage statistics for the user's token consumption

### Example
Expand All @@ -47,7 +47,7 @@ async def get_attestation(user: dict = Depends(get_user)) -> AttestationResponse
"""
Generate a cryptographic attestation report.

- **user**: Authenticated user information (through X-API-Key header)
- **user**: Authenticated user information (through HTTP Bearer header)
- **Returns**: Attestation details for service verification

### Attestation Details
Expand All @@ -70,7 +70,7 @@ async def get_models(user: dict = Depends(get_user)) -> list[ModelMetadata]:
"""
List all available models in the system.

- **user**: Authenticated user information (through X-API-Key header)
- **user**: Authenticated user information (through HTTP Bearer header)
- **Returns**: Dictionary of available models

### Example
Expand Down Expand Up @@ -100,7 +100,7 @@ async def chat_completion(
Generate a chat completion response from the AI model.

- **req**: Chat completion request containing messages and model specifications
- **user**: Authenticated user information (through X-API-Key header)
- **user**: Authenticated user information (through HTTP Bearer header)
- **Returns**: Full chat response with model output, usage statistics, and cryptographic signature

### Request Requirements
Expand Down
198 changes: 91 additions & 107 deletions nilai-api/src/nilai_api/sev/sev.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,100 @@
import base64
import ctypes
import os
from ctypes import c_char_p, c_int, create_string_buffer
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class SEVGuest:
    """ctypes wrapper around the SEV guest attestation shared library.

    When ``libsevguest.so`` cannot be found or loaded next to this module,
    the instance stays in a mock mode: each operation logs a warning and
    returns a placeholder result instead of touching real hardware.
    """

    def __init__(self):
        # None means mock mode (library absent or failed to load).
        self.lib: Optional[ctypes.CDLL] = None
        self._load_library()

    def _load_library(self) -> None:
        """Try to load libsevguest.so from this module's directory."""
        try:
            module_dir = os.path.dirname(os.path.abspath(__file__))
            lib_path = f"{module_dir}/libsevguest.so"
            if not os.path.exists(lib_path):
                logger.warning(f"SEV library not found at {lib_path}")
                return
            self.lib = ctypes.CDLL(lib_path)
            self._setup_library_functions()
        except Exception as e:
            logger.warning(f"Failed to load SEV library: {e}")
            self.lib = None

    def _setup_library_functions(self) -> None:
        """Declare C signatures for every entry point we call."""
        lib = self.lib
        if lib is None:
            return
        # Zero-argument initializers all return an int status code.
        for status_fn in (lib.OpenDevice, lib.GetQuoteProvider, lib.Init):
            status_fn.restype = ctypes.c_int
        # GetQuote takes the 64-byte report buffer and returns a C string.
        lib.GetQuote.restype = ctypes.c_char_p
        lib.GetQuote.argtypes = [ctypes.c_char_p]
        # VerifyQuote returns 0 on success.
        lib.VerifyQuote.restype = ctypes.c_int
        lib.VerifyQuote.argtypes = [ctypes.c_char_p]
        lib.free.argtypes = [ctypes.c_char_p]

    def init(self) -> bool:
        """Initialize the device and quote provider; True on success."""
        if not self.lib:
            logger.warning("SEV library not loaded, running in mock mode")
            return True
        return self.lib.Init() == 0

    def get_quote(self, report_data: Optional[bytes] = None) -> str:
        """Return a base64-encoded quote over *report_data* (exactly 64 bytes).

        Raises:
            ValueError: If report_data is not exactly 64 bytes.
            RuntimeError: If the library fails to produce a quote.
        """
        if not self.lib:
            logger.warning("SEV library not loaded, returning mock quote")
            return base64.b64encode(b"mock_quote").decode("ascii")

        data = bytes(64) if report_data is None else report_data
        if len(data) != 64:
            raise ValueError("Report data must be exactly 64 bytes")

        quote_ptr = self.lib.GetQuote(ctypes.create_string_buffer(data))
        if quote_ptr is None:
            raise RuntimeError("Failed to get quote")

        # NOTE(review): with restype c_char_p ctypes has already converted the
        # result to bytes, so the original C pointer is inaccessible and can
        # never be handed to lib.free — presumably a small leak per call;
        # confirm against the Go/C side's ownership contract.
        return base64.b64encode(ctypes.string_at(quote_ptr)).decode("ascii")

    def verify_quote(self, quote: str) -> bool:
        """Verify a base64-encoded quote; True when the library accepts it."""
        if not self.lib:
            logger.warning(
                "SEV library not loaded, mock verification always returns True"
            )
            return True

        raw = base64.b64decode(quote.encode("ascii"))
        return self.lib.VerifyQuote(ctypes.create_string_buffer(raw)) == 0


# Global instance
# Module-level singleton; importers use `from nilai_api.sev.sev import sev`.
# Constructing it never raises: if libsevguest.so is missing the instance
# simply operates in mock mode.
sev = SEVGuest()

# Load the shared library
# NOTE(review): unlike SEVGuest above, this raises OSError at import time if
# libsevguest.so is not present next to this module — there is no mock mode.
lib = ctypes.CDLL(f"{os.path.dirname(os.path.abspath(__file__))}/libsevguest.so")

# OpenDevice
# Returns an int status code (0 = success).
lib.OpenDevice.restype = c_int

# GetQuoteProvider
# Returns an int status code (0 = success).
lib.GetQuoteProvider.restype = c_int

# Init
# Returns an int status code (0 = success).
lib.Init.restype = c_int

# GetQuote
# Takes the 64-byte report buffer; returns a NUL-terminated C string.
lib.GetQuote.restype = c_char_p
lib.GetQuote.argtypes = [c_char_p]

# VerifyQuote
# Takes the raw quote bytes; returns an int status code (0 = verified).
lib.VerifyQuote.restype = c_int
lib.VerifyQuote.argtypes = [c_char_p]

# C-side free for library-allocated strings (currently unused; see get_quote).
lib.free.argtypes = [c_char_p]


# Python wrapper functions
def init():
    """Initialize the SEV guest device and quote provider.

    Raises:
        RuntimeError: If the underlying C library reports a non-zero status.
    """
    status = lib.Init()
    if status != 0:
        raise RuntimeError("Failed to initialize SEV guest device and quote provider.")


def get_quote(report_data=None) -> str:
    """
    Get a quote using the report data.

    Args:
        report_data (bytes, optional): 64-byte report data.
            Defaults to 64 zero bytes.

    Returns:
        str: The base64-encoded quote.

    Raises:
        ValueError: If report_data is not exactly 64 bytes.
        RuntimeError: If the library fails to produce a quote.
    """
    # Use 64 zero bytes if no report data provided
    if report_data is None:
        report_data = bytes(64)

    # Validate report data
    if len(report_data) != 64:
        raise ValueError("Report data must be exactly 64 bytes")

    # Create a buffer from the report data
    report_buffer = create_string_buffer(report_data)

    # Get the quote
    quote_ptr = lib.GetQuote(report_buffer)

    # Bug fix: check for a NULL result BEFORE reading it — the original
    # called ctypes.string_at(quote_ptr) first, which dereferences a NULL
    # pointer when GetQuote fails, crashing before the error is raised.
    if quote_ptr is None:
        raise RuntimeError("Failed to get quote")

    quote_str = ctypes.string_at(quote_ptr)

    # We should be freeing the quote, but it turns out it raises an error.
    # (With restype c_char_p the original pointer is already converted to
    # bytes by ctypes, so lib.free cannot receive the real address.)
    # lib.free(quote_ptr)

    # Convert quote to a base64 ASCII string
    return base64.b64encode(quote_str).decode("ascii")


def verify_quote(quote: str) -> bool:
    """
    Verify the quote using the library's verification method.

    Args:
        quote (str): The base64-encoded quote to verify.

    Returns:
        bool: True if quote is verified, False otherwise.
    """
    # Coerce non-string inputs so .encode below is safe.
    text = quote if isinstance(quote, str) else str(quote)

    # Decode the base64 payload back to the raw quote bytes.
    raw = base64.b64decode(text.encode("ascii"))

    # The C library reports success as a zero status code.
    return lib.VerifyQuote(create_string_buffer(raw)) == 0


# Example usage
if __name__ == "__main__":
try:
# Initialize the device and quote provider
init()
print("SEV guest device initialized successfully.")

# Create a 64-byte report data array (all zeros for simplicity)
report_data = bytes([0] * 64)

# Get the quote
quote = get_quote(report_data)
print(type(quote))
print("Quote:", quote)

if verify_quote(quote):
print("Quote verified successfully.")
if sev.init():
print("SEV guest device initialized successfully.")
report_data = bytes([0] * 64)
quote = sev.get_quote(report_data)
print("Quote:", quote)

if sev.verify_quote(quote):
print("Quote verified successfully.")
else:
print("Quote verification failed.")
else:
print("Quote verification failed.")
print("Failed to initialize SEV guest device.")
except Exception as e:
print("Error:", e)
10 changes: 5 additions & 5 deletions nilai-api/src/nilai_api/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dotenv import load_dotenv
from nilai_api.crypto import generate_key_pair
from nilai_api.sev.sev import get_quote, init
from nilai_api.sev.sev import sev
from nilai_common import ModelServiceDiscovery, SETTINGS
from nilai_common.api_model import ModelEndpoint

Expand All @@ -22,21 +22,21 @@ def __init__(self):
)
self._uptime = time.time()
self._cpu_quote = None
self._gpu_quote = None
self._gpu_quote = "<No GPU>"

@property
def cpu_attestation(self) -> str:
if self._cpu_quote is None:
try:
init()
self._cpu_quote = get_quote()
sev.init()
self._cpu_quote = sev.get_quote()
except RuntimeError:
self._cpu_quote = "<Non TEE CPU>"
return self._cpu_quote

@property
def gpu_attestation(self) -> str:
return "<No GPU>"
return self._gpu_quote

@property
def uptime(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self):
repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
n_threads=16,
n_ctx=2048,
verbose=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self):
repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
filename="Meta-Llama-3-8B-Instruct-Q5_K_M.gguf",
n_threads=16,
n_ctx=2048,
verbose=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self):
repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-Q5_K_M.gguf",
n_threads=16,
n_ctx=2048,
verbose=False,
)

Expand Down Expand Up @@ -94,10 +95,12 @@ async def chat_completion(
for msg in req.messages
]

prompt += [{
"role": "system",
"content": "In addition to the previous. You are a cheese expert. You use cheese for all your answers. Whatever the user asks, you respond with a cheese-related answer or analogy.",
}]
prompt += [
{
"role": "system",
"content": "In addition to the previous. You are a cheese expert. You use cheese for all your answers. Whatever the user asks, you respond with a cheese-related answer or analogy.",
}
]
# Generate chat completion using the Llama model
# - Converts messages into a model-compatible prompt
# - type: ignore suppresses type checking for external library
Expand Down
4 changes: 3 additions & 1 deletion packages/nilai-common/src/nilai_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
AttestationResponse,
ChatRequest,
ChatResponse,
Choice,
HealthCheckResponse,
Message,
ModelEndpoint,
ModelMetadata,
Usage,
)
from nilai_common.config import SETTINGS
from nilai_common.db import ModelServiceDiscovery
from nilai_common.discovery import ModelServiceDiscovery

__all__ = [
"Message",
"ChatRequest",
"ChatResponse",
"Choice",
"ModelMetadata",
"Usage",
"AttestationResponse",
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ dependencies = [
dev = [
"black>=24.10.0",
"isort>=5.13.2",
"pytest-mock>=3.14.0",
"pytest>=8.3.3",
"ruff>=0.8.0",
"uvicorn>=0.32.1",
"pytest-asyncio>=0.25.0",
]

[build-system]
Expand Down
Loading
Loading