diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 00000000..1ddbc6a3
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,34 @@
+name: Python Tests
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ] # Adjust branches as needed
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      # setup-uv's enable-cache already restores and saves the uv cache,
+      # so no separate actions/cache step is needed.
+      - uses: astral-sh/setup-uv@v4
+        with:
+          enable-cache: true
+          cache-dependency-glob: "**/pyproject.toml"
+
+      - name: Install dependencies
+        run: |
+          uv sync
+
+      - name: Run Ruff format check
+        run: uv run ruff format --check
+
+      - name: Run Ruff linting
+        run: uv run ruff check
+
+      - name: Run tests
+        run: uv run pytest -v
\ No newline at end of file
diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py
index 1923967c..0f6fb47c 100644
--- a/nilai-api/src/nilai_api/routers/private.py
+++ b/nilai-api/src/nilai_api/routers/private.py
@@ -30,7 +30,7 @@ async def get_usage(user: dict = Depends(get_user)) -> Usage:
     """
     Retrieve the current token usage for the authenticated user.
 
-    - **user**: Authenticated user information (through X-API-Key header)
+    - **user**: Authenticated user information (via HTTP Bearer authentication)
     - **Returns**: Usage statistics for the user's token consumption
 
     ### Example
@@ -47,7 +47,7 @@ async def get_attestation(user: dict = Depends(get_user)) -> AttestationResponse
     """
     Generate a cryptographic attestation report.
 
-    - **user**: Authenticated user information (through X-API-Key header)
+    - **user**: Authenticated user information (via HTTP Bearer authentication)
     - **Returns**: Attestation details for service verification
 
     ### Attestation Details
@@ -70,7 +70,7 @@ async def get_models(user: dict = Depends(get_user)) -> list[ModelMetadata]:
     """
     List all available models in the system.
 
-    - **user**: Authenticated user information (through X-API-Key header)
+    - **user**: Authenticated user information (via HTTP Bearer authentication)
     - **Returns**: Dictionary of available models
 
     ### Example
@@ -100,7 +100,7 @@ async def chat_completion(
     Generate a chat completion response from the AI model.
 
     - **req**: Chat completion request containing messages and model specifications
-    - **user**: Authenticated user information (through X-API-Key header)
+    - **user**: Authenticated user information (via HTTP Bearer authentication)
     - **Returns**: Full chat response with model output, usage statistics, and cryptographic signature
 
     ### Request Requirements
diff --git a/nilai-api/src/nilai_api/sev/sev.py b/nilai-api/src/nilai_api/sev/sev.py
index d688bc54..c9090b00 100644
--- a/nilai-api/src/nilai_api/sev/sev.py
+++ b/nilai-api/src/nilai_api/sev/sev.py
@@ -1,116 +1,100 @@
 import base64
 import ctypes
 import os
-from ctypes import c_char_p, c_int, create_string_buffer
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class SEVGuest:
+    def __init__(self):
+        self.lib: Optional[ctypes.CDLL] = None
+        self._load_library()
+
+    def _load_library(self) -> None:
+        try:
+            lib_path = f"{os.path.dirname(os.path.abspath(__file__))}/libsevguest.so"
+            if not os.path.exists(lib_path):
+                logger.warning(f"SEV library not found at {lib_path}")
+                return
+
+            self.lib = ctypes.CDLL(lib_path)
+            self._setup_library_functions()
+        except Exception as e:
+            logger.warning(f"Failed to load SEV library: {e}")
+            self.lib = None
+
+    def _setup_library_functions(self) -> None:
+        if not self.lib:
+            return
+
+        self.lib.OpenDevice.restype = ctypes.c_int
+        self.lib.GetQuoteProvider.restype = ctypes.c_int
+        self.lib.Init.restype = ctypes.c_int
+        self.lib.GetQuote.restype = ctypes.c_char_p
+        self.lib.GetQuote.argtypes = [ctypes.c_char_p]
+        self.lib.VerifyQuote.restype = ctypes.c_int
+        self.lib.VerifyQuote.argtypes = [ctypes.c_char_p]
+        self.lib.free.argtypes = [ctypes.c_char_p]
+
+    def init(self) -> bool:
+        """Initialize the device and quote provider."""
+        if not self.lib:
+            logger.warning("SEV library not loaded, running in mock mode")
+            return True
+        return self.lib.Init() == 0
+
+    def get_quote(self, report_data: Optional[bytes] = None) -> str:
+        """Get a quote using the report data."""
+        if not self.lib:
+            logger.warning("SEV library not loaded, returning mock quote")
+            return base64.b64encode(b"mock_quote").decode("ascii")
+
+        if report_data is None:
+            report_data = bytes(64)
+
+        if len(report_data) != 64:
+            raise ValueError("Report data must be exactly 64 bytes")
+
+        report_buffer = ctypes.create_string_buffer(report_data)
+        quote_ptr = self.lib.GetQuote(report_buffer)
+
+        if quote_ptr is None:
+            raise RuntimeError("Failed to get quote")
+
+        quote_str = ctypes.string_at(quote_ptr)
+        return base64.b64encode(quote_str).decode("ascii")
+
+    def verify_quote(self, quote: str) -> bool:
+        """Verify the quote using the library's verification method."""
+        if not self.lib:
+            logger.warning(
+                "SEV library not loaded, mock verification always returns True"
+            )
+            return True
+
+        quote_bytes = base64.b64decode(quote.encode("ascii"))
+        quote_buffer = ctypes.create_string_buffer(quote_bytes)
+        return self.lib.VerifyQuote(quote_buffer) == 0
+
+
+# Global instance
+sev = SEVGuest()
 
-# Load the shared library
-lib = ctypes.CDLL(f"{os.path.dirname(os.path.abspath(__file__))}/libsevguest.so")
-
-# OpenDevice
-lib.OpenDevice.restype = c_int
-
-# GetQuoteProvider
-lib.GetQuoteProvider.restype = c_int
-
-# Init
-lib.Init.restype = c_int
-
-# GetQuote
-lib.GetQuote.restype = c_char_p
-lib.GetQuote.argtypes = [c_char_p]
-
-# VerifyQuote
-lib.VerifyQuote.restype = c_int
-lib.VerifyQuote.argtypes = [c_char_p]
-
-lib.free.argtypes = [c_char_p]
-
-
-# Python wrapper functions
-def init():
-    """Initialize the device and 
quote provider.""" - if lib.Init() != 0: - raise RuntimeError("Failed to initialize SEV guest device and quote provider.") - - -def get_quote(report_data=None) -> str: - """ - Get a quote using the report data. - - Args: - report_data (bytes, optional): 64-byte report data. - Defaults to 64 zero bytes. - - Returns: - str: The quote as a string - """ - # Use 64 zero bytes if no report data provided - if report_data is None: - report_data = bytes(64) - - # Validate report data - if len(report_data) != 64: - raise ValueError("Report data must be exactly 64 bytes") - - # Create a buffer from the report data - report_buffer = create_string_buffer(report_data) - - # Get the quote - quote_ptr = lib.GetQuote(report_buffer) - quote_str = ctypes.string_at(quote_ptr) - - # We should be freeing the quote, but it turns out it raises an error. - # lib.free(quote_ptr) - # Check if quote retrieval failed - if quote_ptr is None: - raise RuntimeError("Failed to get quote") - - # Convert quote to Python string - quote = base64.b64encode(quote_str) - return quote.decode("ascii") - - -def verify_quote(quote: str) -> bool: - """ - Verify the quote using the library's verification method. - - Args: - quote (str): The quote to verify - - Returns: - bool: True if quote is verified, False otherwise - """ - # Ensure quote is a string - if not isinstance(quote, str): - quote = str(quote) - - # Convert to bytes - quote_bytes = base64.b64decode(quote.encode("ascii")) - quote_buffer = create_string_buffer(quote_bytes) - - # Verify quote - result = lib.VerifyQuote(quote_buffer) - return result == 0 - - -# Example usage if __name__ == "__main__": try: - # Initialize the device and quote provider - init() - print("SEV guest device initialized successfully.") - - # Create a 64-byte report data array (all zeros for simplicity) - report_data = bytes([0] * 64) - - # Get the quote - quote = get_quote(report_data) - print(type(quote)) - print("Quote:", quote) - - if verify_quote(quote): - print("Quote verified successfully.") + if sev.init(): + print("SEV guest device initialized successfully.") + report_data = bytes([0] * 64) + quote = sev.get_quote(report_data) + print("Quote:", quote) + + if sev.verify_quote(quote): + print("Quote verified successfully.") + else: + print("Quote verification failed.") else: - print("Quote verification failed.") + print("Failed to initialize SEV guest device.") except Exception as e: print("Error:", e) diff --git a/nilai-api/src/nilai_api/state.py b/nilai-api/src/nilai_api/state.py index d023aa84..e1680a92 100644 --- a/nilai-api/src/nilai_api/state.py +++ b/nilai-api/src/nilai_api/state.py @@ -5,7 +5,7 @@ from dotenv import load_dotenv from nilai_api.crypto import generate_key_pair -from nilai_api.sev.sev import get_quote, init +from nilai_api.sev.sev import sev from nilai_common import ModelServiceDiscovery, SETTINGS from nilai_common.api_model import ModelEndpoint @@ -22,21 +22,21 @@ def __init__(self): ) self._uptime = time.time() self._cpu_quote = None - self._gpu_quote = None + self._gpu_quote = "" @property def cpu_attestation(self) -> str: if self._cpu_quote is None: try: - init() - self._cpu_quote = get_quote() + sev.init() + self._cpu_quote = sev.get_quote() except RuntimeError: self._cpu_quote = "" return self._cpu_quote @property def gpu_attestation(self) -> str: - return "" + return self._gpu_quote @property def uptime(self): diff --git a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py index 
e133a0b1..f11c8e6a 100644
--- a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
+++ b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
@@ -34,6 +34,7 @@ def __init__(self):
             repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
             filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
             n_threads=16,
+            n_ctx=2048,
             verbose=False,
         )
 
diff --git a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
index 57026944..bb9802bd 100644
--- a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
+++ b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
@@ -34,6 +34,7 @@ def __init__(self):
             repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
             filename="Meta-Llama-3-8B-Instruct-Q5_K_M.gguf",
             n_threads=16,
+            n_ctx=2048,
             verbose=False,
         )
 
diff --git a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
index 70f4d43e..8084caca 100644
--- a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
+++ b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
@@ -34,6 +34,7 @@ def __init__(self):
             repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
             filename="Llama-3.2-1B-Instruct-Q5_K_M.gguf",
             n_threads=16,
+            n_ctx=2048,
             verbose=False,
         )
 
@@ -94,10 +95,12 @@ async def chat_completion(
             for msg in req.messages
         ]
 
-        prompt += [{
-            "role": "system",
-            "content": "In addition to the previous. You are a cheese expert. You use cheese for all your answers. Whatever the user asks, you respond with a cheese-related answer or analogy.",
-        }]
+        prompt += [
+            {
+                "role": "system",
+                "content": "In addition to the previous instructions, you are a cheese expert. You use cheese for all your answers. 
Whatever the user asks, you respond with a cheese-related answer or analogy.", + } + ] # Generate chat completion using the Llama model # - Converts messages into a model-compatible prompt # - type: ignore suppresses type checking for external library diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 4cf49984..44e39f63 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -2,6 +2,7 @@ AttestationResponse, ChatRequest, ChatResponse, + Choice, HealthCheckResponse, Message, ModelEndpoint, @@ -9,12 +10,13 @@ Usage, ) from nilai_common.config import SETTINGS -from nilai_common.db import ModelServiceDiscovery +from nilai_common.discovery import ModelServiceDiscovery __all__ = [ "Message", "ChatRequest", "ChatResponse", + "Choice", "ModelMetadata", "Usage", "AttestationResponse", diff --git a/packages/nilai-common/src/nilai_common/db.py b/packages/nilai-common/src/nilai_common/discovery.py similarity index 100% rename from packages/nilai-common/src/nilai_common/db.py rename to packages/nilai-common/src/nilai_common/discovery.py diff --git a/pyproject.toml b/pyproject.toml index dd835824..f49adc47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,11 @@ dependencies = [ dev = [ "black>=24.10.0", "isort>=5.13.2", + "pytest-mock>=3.14.0", "pytest>=8.3.3", "ruff>=0.8.0", "uvicorn>=0.32.1", + "pytest-asyncio>=0.25.0", ] [build-system] diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..694d9aba 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,40 @@ +from nilai_common import ( + ChatResponse, + Choice, + Message, + ModelEndpoint, + ModelMetadata, + Usage, +) + +model_metadata: ModelMetadata = ModelMetadata( + id="ABC", # Unique identifier + name="ABC", # Human-readable name + version="1.0", # Model version + description="Description", + author="Author", # Model creators + license="License", # Usage license + source="http://test-model-url", # Model source + supported_features=["supported_feature"], # Capabilities +) + +model_endpoint: ModelEndpoint = ModelEndpoint( + url="http://test-model-url", metadata=model_metadata +) + +response: ChatResponse = ChatResponse( + id="test-id", + object="test-object", + model="test-model", + created=123456, + choices=[ + Choice( + index=0, + message=Message(role="test-role", content="test-content"), + finish_reason="test-finish-reason", + logprobs={"test-logprobs": "test-value"}, + ) + ], + usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150), + signature="test-signature", +) diff --git a/tests/model_execution_0.py b/tests/model_execution_0.py deleted file mode 100644 index f7bbe1ae..00000000 --- a/tests/model_execution_0.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import time - -import torch -from dotenv import load_dotenv -from transformers import pipeline - -# Load the .env file -load_dotenv() - -# # Application State Initialization -torch.set_num_threads(32) -torch.set_num_interop_threads(32) - - -chat_pipeline = pipeline( - "text-generation", - model="meta-llama/Llama-3.2-1B-Instruct", - model_kwargs={"torch_dtype": torch.bfloat16}, - device_map="cpu", - token=os.getenv("HUGGINGFACE_API_TOKEN"), -) - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -start = time.time() -# Generate response -generated = chat_pipeline( - messages, max_length=1024, num_return_sequences=1, 
truncation=True -) # type: ignore - -end = time.time() - -print(generated) -print(end - start) diff --git a/tests/model_execution_1.py b/tests/model_execution_1.py deleted file mode 100644 index 2ef17ab4..00000000 --- a/tests/model_execution_1.py +++ /dev/null @@ -1,49 +0,0 @@ -import time - -from onnxruntime import InferenceSession -from optimum.onnxruntime import ORTModelForCausalLM -from transformers import AutoTokenizer - -# Define the model directory and ONNX export location -model_name = "meta-llama/Llama-3.2-1B-Instruct" -onnx_export_dir = "./onnx_model" - -# Export the model -model = ORTModelForCausalLM.from_pretrained(model_name, from_transformers=True) -model.save_pretrained(onnx_export_dir) - -# Save the tokenizer for later use -tokenizer = AutoTokenizer.from_pretrained(model_name) -tokenizer.save_pretrained(onnx_export_dir) - - -# Load the ONNX model and tokenizer -onnx_model_path = "./onnx_model/model.onnx" -tokenizer = AutoTokenizer.from_pretrained("./onnx_model") - -# Create an ONNX Runtime session -session = InferenceSession(onnx_model_path) - -# Input messages -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -# Prepare input text -input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - -# Tokenize input text -inputs = tokenizer(input_text, return_tensors="pt") -print("START:") -# Run inference -start = time.time() -onnx_inputs = {session.get_inputs()[0].name: inputs["input_ids"].numpy()} -onnx_output = session.run(None, onnx_inputs) -end = time.time() - -# Decode the output -output_text = tokenizer.decode(onnx_output[0][0], skip_special_tokens=True) - -print(output_text) -print(f"Time taken: {end - start} seconds") diff --git a/tests/model_execution_2.py b/tests/model_execution_2.py deleted file mode 100644 index 821279ae..00000000 --- a/tests/model_execution_2.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import time - -from dotenv import load_dotenv -from optimum.pipelines import pipeline - -# Load the .env file -load_dotenv() - - -chat_pipeline = pipeline( - "text-generation", - model="meta-llama/Llama-3.2-1B-Instruct", - accelerator="ort", - token=os.getenv("HUGGINGFACE_API_TOKEN"), -) - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -print("start") -for i in range(10): - start = time.time() - # Generate response - generated = chat_pipeline( - messages, max_length=1024, num_return_sequences=1, truncation=True - ) # type: ignore - - end = time.time() - - print(generated) - print(end - start) diff --git a/tests/model_execution_3.py b/tests/model_execution_3.py deleted file mode 100644 index 33c4ace7..00000000 --- a/tests/model_execution_3.py +++ /dev/null @@ -1,25 +0,0 @@ -import time - -from llama_cpp import Llama - -llm = Llama.from_pretrained( - repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF", - filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf", -) - - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -print("start") -for i in range(10): - start = time.time() - # Generate response - generated = llm.create_chat_completion(messages) - - end = time.time() - - print(generated) - print(end - start) diff --git a/tests/nilai-common/__init__.py b/tests/nilai-common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nilai-common/test_api_model.py 
b/tests/nilai-common/test_api_model.py new file mode 100644 index 00000000..ac59db6e --- /dev/null +++ b/tests/nilai-common/test_api_model.py @@ -0,0 +1,54 @@ +import pytest +from pydantic import ValidationError +from nilai_common.api_model import ModelMetadata + + +def test_model_metadata_creation(): + """Test creating a ModelMetadata instance.""" + metadata = ModelMetadata( + name="Test Model", + version="1.0", + description="A test model", + author="Test Author", + license="MIT", + source="https://example.com", + supported_features=["feature1", "feature2"], + ) + + assert metadata.id is not None + assert metadata.name == "Test Model" + assert metadata.version == "1.0" + assert metadata.description == "A test model" + assert metadata.author == "Test Author" + assert metadata.license == "MIT" + assert metadata.source == "https://example.com" + assert metadata.supported_features == ["feature1", "feature2"] + + +def test_model_metadata_default_id(): + """Test that ModelMetadata generates a default UUID for id.""" + metadata = ModelMetadata( + name="Test Model", + version="1.0", + description="A test model", + author="Test Author", + license="MIT", + source="https://example.com", + supported_features=["feature1", "feature2"], + ) + + assert metadata.id is not None + assert len(metadata.id) == 36 # UUID length + + +def test_model_metadata_invalid_data(): + """Test creating ModelMetadata with invalid data.""" + with pytest.raises(ValidationError): + ModelMetadata( + name="", + version="", + description="", + author="", + license="", + source="", + ) # type: ignore diff --git a/tests/nilai-common/test_discovery.py b/tests/nilai-common/test_discovery.py new file mode 100644 index 00000000..d1714f5f --- /dev/null +++ b/tests/nilai-common/test_discovery.py @@ -0,0 +1,94 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from nilai_common.discovery import ModelServiceDiscovery +from nilai_common.api_model import ModelEndpoint, ModelMetadata + + +@pytest.fixture +def model_service_discovery(): + with patch("nilai_common.discovery.Etcd3Client") as MockClient: + mock_client = MockClient.return_value + discovery = ModelServiceDiscovery() + discovery.client = mock_client + yield discovery + + +@pytest.fixture +def model_endpoint(): + model_metadata = ModelMetadata( + name="Test Model", + version="1.0.0", + description="Test model description", + author="Test Author", + license="MIT", + source="https://github.com/test/model", + supported_features=["test_feature"], + ) + return ModelEndpoint( + url="http://test-model-service.example.com/predict", metadata=model_metadata + ) + + +@pytest.mark.asyncio +async def test_register_model(model_service_discovery, model_endpoint): + lease_mock = MagicMock() + model_service_discovery.client.lease.return_value = lease_mock + + lease = await model_service_discovery.register_model(model_endpoint) + + model_service_discovery.client.put.assert_called_once_with( + f"/models/{model_endpoint.metadata.id}", + model_endpoint.model_dump_json(), + lease=lease_mock, + ) + assert lease == lease_mock + + +@pytest.mark.asyncio +async def test_discover_models(model_service_discovery, model_endpoint): + model_service_discovery.client.get_prefix.return_value = [ + (model_endpoint.model_dump_json().encode("utf-8"), None) + ] + + discovered_models = await model_service_discovery.discover_models() + + assert len(discovered_models) == 1 + assert model_endpoint.metadata.id in discovered_models + assert discovered_models[model_endpoint.metadata.id] == 
model_endpoint
+
+
+@pytest.mark.asyncio
+async def test_get_model(model_service_discovery, model_endpoint):
+    model_service_discovery.client.get.return_value = (
+        model_endpoint.model_dump_json().encode("utf-8"),
+        None,
+    )
+
+    model = await model_service_discovery.get_model(model_endpoint.metadata.id)
+
+    assert model == model_endpoint
+
+
+@pytest.mark.asyncio
+async def test_unregister_model(model_service_discovery, model_endpoint):
+    await model_service_discovery.unregister_model(model_endpoint.metadata.id)
+
+    model_service_discovery.client.delete.assert_called_once_with(
+        f"/models/{model_endpoint.metadata.id}"
+    )
+
+
+@pytest.mark.asyncio
+async def test_keep_alive(model_service_discovery):
+    lease_mock = MagicMock()
+    lease_mock.refresh = AsyncMock()
+
+    async def keep_alive_task():
+        await model_service_discovery.keep_alive(lease_mock)
+
+    task = asyncio.create_task(keep_alive_task())
+    await asyncio.sleep(0.1)
+    task.cancel()
+
+    lease_mock.refresh.assert_called()
diff --git a/tests/nilai_api/__init__.py b/tests/nilai_api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/nilai_api/routers/__init__.py b/tests/nilai_api/routers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/nilai_api/routers/test_private.py b/tests/nilai_api/routers/test_private.py
new file mode 100644
index 00000000..7adc31f9
--- /dev/null
+++ b/tests/nilai_api/routers/test_private.py
@@ -0,0 +1,170 @@
+from unittest.mock import AsyncMock
+from nilai_common.api_model import (
+    ChatResponse,
+    Choice,
+    Message,
+    Usage,
+)
+import pytest
+import asyncio
+from fastapi.testclient import TestClient
+from nilai_api.app import app
+from nilai_api.db import UserManager
+from nilai_api.state import state
+
+from tests import model_metadata, model_endpoint
+
+client = TestClient(app)
+
+
+@pytest.mark.asyncio
+async def test_runs_in_a_loop():
+    assert asyncio.get_running_loop()
+
+
+@pytest.fixture
+def mock_user():
+    return {"userid": "test-user-id", "name": "Test User"}
+
+
+@pytest.fixture
+def mock_user_manager(mocker):
+    mocker.patch.object(
+        UserManager,
+        "get_token_usage",
+        return_value={
+            "prompt_tokens": 100,
+            "completion_tokens": 50,
+            "total_tokens": 150,
+            "queries": 10,
+        },
+    )
+    mocker.patch.object(UserManager, "update_token_usage")
+    mocker.patch.object(
+        UserManager,
+        "get_user_token_usage",
+        return_value={"prompt_tokens": 100, "completion_tokens": 50},
+    )
+    mocker.patch.object(
+        UserManager,
+        "insert_user",
+        return_value={"userid": "test-user-id", "apikey": "test-api-key"},
+    )
+    mocker.patch.object(
+        UserManager,
+        "check_api_key",
+        return_value={"name": "Test User", "userid": "test-user-id"},
+    )
+    mocker.patch.object(
+        UserManager,
+        "get_all_users",
+        return_value=[
+            {"userid": "test-user-id", "apikey": "test-api-key"},
+            {"userid": "test-user-id-2", "apikey": "test-api-key"},
+        ],
+    )
+
+
+@pytest.fixture
+def mock_state(mocker):
+    # Prepare expected models data
+
+    expected_models = {"ABC": model_endpoint}
+
+    # Create a mock discovery service that returns the expected models
+    mock_discovery_service = mocker.Mock()
+    mock_discovery_service.discover_models = AsyncMock(return_value=expected_models)
+
+    # Create a mock AppState
+    mocker.patch.object(state, "discovery_service", mock_discovery_service)
+
+    # Patch other attributes
+    mocker.patch.object(state, "verifying_key", "test-verifying-key")
+    mocker.patch.object(state, "_cpu_quote", "test-cpu-attestation")
+    mocker.patch.object(state, "_gpu_quote", 
"test-gpu-attestation") + + # Patch get_model method + mocker.patch.object(state, "get_model", return_value=model_endpoint) + + return state + + +# Example test +@pytest.mark.asyncio +async def test_models_property(mock_state): + # Retrieve the models + models = await state.models + + # Assert the expected models + assert models == {"ABC": model_endpoint} + + +def test_get_usage(mock_user, mock_user_manager): + response = client.get("/v1/usage", headers={"Authorization": "Bearer test-api-key"}) + assert response.status_code == 200 + assert response.json() == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } + + +def test_get_attestation(mock_user, mock_user_manager, mock_state): + response = client.get( + "/v1/attestation/report", headers={"Authorization": "Bearer test-api-key"} + ) + assert response.status_code == 200 + assert response.json() == { + "verifying_key": "test-verifying-key", + "cpu_attestation": "test-cpu-attestation", + "gpu_attestation": "test-gpu-attestation", + } + + +def test_get_models(mock_user, mock_user_manager, mock_state): + response = client.get( + "/v1/models", headers={"Authorization": "Bearer test-api-key"} + ) + assert response.status_code == 200 + assert response.json() == [model_metadata.model_dump()] + + +def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker): + response = ChatResponse( + id="test-id", + object="test-object", + model="test-model", + created=123456, + choices=[ + Choice( + index=0, + message=Message(role="test-role", content="test-content"), + finish_reason="test-finish-reason", + logprobs={"test-logprobs": "test-value"}, + ) + ], + usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150), + signature="test-signature", + ) + mocker.patch( + "httpx.AsyncClient.post", + return_value=mocker.Mock(status_code=200, content=response.model_dump_json()), + ) + response = client.post( + "/v1/chat/completions", + json={ + "model": "Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is your name?"}, + ], + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 200 + assert "usage" in response.json() + assert response.json()["usage"] == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } diff --git a/tests/nilai_api/routers/test_public.py b/tests/nilai_api/routers/test_public.py new file mode 100644 index 00000000..3cb9c8c4 --- /dev/null +++ b/tests/nilai_api/routers/test_public.py @@ -0,0 +1,18 @@ +from fastapi.testclient import TestClient +from nilai_api.routers.public import router +from fastapi import FastAPI + +app = FastAPI() +app.include_router(router) + +client = TestClient(app) + + +def test_health_check(): + """Test the health check endpoint.""" + response = client.get("/v1/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "uptime" in data + assert isinstance(data["uptime"], str) diff --git a/tests/nilai_api/sev/__init__.py b/tests/nilai_api/sev/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nilai_api/sev/test_sev.py b/tests/nilai_api/sev/test_sev.py new file mode 100644 index 00000000..545559bd --- /dev/null +++ b/tests/nilai_api/sev/test_sev.py @@ -0,0 +1,64 @@ +import base64 +import ctypes +import pytest +from nilai_api.sev.sev import SEVGuest + + +@pytest.fixture +def sev_guest(): + return SEVGuest() + + +def 
test_init_success(sev_guest, mocker):
+    mocker.patch.object(sev_guest, "_load_library", return_value=None)
+    sev_guest.lib = mocker.Mock()
+    sev_guest.lib.Init.return_value = 0
+    assert sev_guest.init() is True
+
+
+def test_init_failure(sev_guest, mocker):
+    mocker.patch.object(sev_guest, "_load_library", return_value=None)
+    sev_guest.lib = mocker.Mock()
+    sev_guest.lib.Init.return_value = -1
+    assert sev_guest.init() is False
+
+
+def test_get_quote_success(sev_guest, mocker):
+    mocker.patch.object(sev_guest, "_load_library", return_value=None)
+    sev_guest.lib = mocker.Mock()
+    sev_guest.lib.GetQuote.return_value = ctypes.create_string_buffer(b"quote_data")
+    report_data = bytes([0] * 64)
+    quote = sev_guest.get_quote(report_data)
+    expected_quote = base64.b64encode(b"quote_data").decode("ascii")
+    assert quote == expected_quote
+
+
+def test_get_quote_failure(sev_guest, mocker):
+    mocker.patch.object(sev_guest, "_load_library", return_value=None)
+    sev_guest.lib = mocker.Mock()
+    sev_guest.lib.GetQuote.return_value = None
+    report_data = bytes([0] * 64)
+    with pytest.raises(RuntimeError):
+        sev_guest.get_quote(report_data)
+
+
+def test_get_quote_invalid_report_data(sev_guest, mocker):
+    sev_guest.lib = mocker.Mock()  # force the loaded-library path so the length check always runs
+    with pytest.raises(ValueError):
+        sev_guest.get_quote(bytes([0] * 63))
+
+
+def test_verify_quote_success(sev_guest, mocker):
+    mocker.patch.object(sev_guest, "_load_library", return_value=None)
+    sev_guest.lib = mocker.Mock()
+    sev_guest.lib.VerifyQuote.return_value = 0
+    quote = base64.b64encode(b"quote_data").decode("ascii")
+    assert sev_guest.verify_quote(quote) is True
+
+
+def test_verify_quote_failure(sev_guest, mocker):
+    mocker.patch.object(sev_guest, "_load_library", return_value=None)
+    sev_guest.lib = mocker.Mock()
+    sev_guest.lib.VerifyQuote.return_value = -1
+    quote = base64.b64encode(b"quote_data").decode("ascii")
+    assert sev_guest.verify_quote(quote) is False
diff --git a/tests/nilai_api/test_app.py b/tests/nilai_api/test_app.py
new file mode 100644
index 00000000..1b2e8622
--- /dev/null
+++ b/tests/nilai_api/test_app.py
@@ -0,0 +1,10 @@
+from fastapi.testclient import TestClient
+from nilai_api.app import app
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert "openapi" in response.json()
diff --git a/tests/nilai_api/test_auth.py b/tests/nilai_api/test_auth.py
new file mode 100644
index 00000000..b550fcbb
--- /dev/null
+++ b/tests/nilai_api/test_auth.py
@@ -0,0 +1,38 @@
+import pytest
+from fastapi import HTTPException
+from fastapi.security import HTTPAuthorizationCredentials
+from nilai_api.auth import get_user
+from nilai_api.db import UserManager
+
+
+@pytest.fixture
+def mock_user_manager(mocker):
+    """Fixture to mock UserManager methods."""
+    mocker.patch.object(UserManager, "check_api_key")
+    return UserManager
+
+
+def test_get_user_valid_token(mock_user_manager):
+    """Test get_user with a valid token."""
+    mock_user_manager.check_api_key.return_value = {
+        "name": "Test User",
+        "userid": "test-user-id",
+    }
+    credentials = HTTPAuthorizationCredentials(
+        scheme="Bearer", credentials="valid-token"
+    )
+    user = get_user(credentials)
+    assert user["name"] == "Test User"
+    assert user["userid"] == "test-user-id"
+
+
+def test_get_user_invalid_token(mock_user_manager):
+    """Test get_user with an invalid token."""
+    mock_user_manager.check_api_key.return_value = None
+    credentials = HTTPAuthorizationCredentials(
+        scheme="Bearer", credentials="invalid-token"
+    )
+    
with pytest.raises(HTTPException) as exc_info: + get_user(credentials) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Missing or invalid API key" diff --git a/tests/test_cryptography.py b/tests/nilai_api/test_cryptography.py similarity index 96% rename from tests/test_cryptography.py rename to tests/nilai_api/test_cryptography.py index 702ca540..aa75dd6f 100644 --- a/tests/test_cryptography.py +++ b/tests/nilai_api/test_cryptography.py @@ -4,7 +4,7 @@ from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import ec -from nilai.crypto import generate_key_pair, sign_message, verify_signature +from nilai_api.crypto import generate_key_pair, sign_message, verify_signature def test_generate_key_pair(): diff --git a/tests/test_db.py b/tests/nilai_api/test_db.py similarity index 90% rename from tests/test_db.py rename to tests/nilai_api/test_db.py index ae713125..e9b1d923 100644 --- a/tests/test_db.py +++ b/tests/nilai_api/test_db.py @@ -79,8 +79,8 @@ def test_insert_user(self, user_manager): # Verify user can be retrieved retrieved_user_tokens = user_manager.get_user_token_usage(user_data["userid"]) assert retrieved_user_tokens is not None - assert retrieved_user_tokens["input_tokens"] == 0 - assert retrieved_user_tokens["generated_tokens"] == 0 + assert retrieved_user_tokens["prompt_tokens"] == 0 + assert retrieved_user_tokens["completion_tokens"] == 0 def test_check_api_key(self, user_manager): """Test API key validation.""" @@ -89,7 +89,8 @@ def test_check_api_key(self, user_manager): # Check valid API key user_name = user_manager.check_api_key(user_data["apikey"]) - assert user_name == "Check API User" + assert user_name["name"] == "Check API User" + assert user_name["userid"] == user_data["userid"] # Check invalid API key invalid_result = user_manager.check_api_key("invalid-api-key") @@ -102,24 +103,24 @@ def test_update_token_usage(self, user_manager): # Update token usage user_manager.update_token_usage( - user_data["userid"], input_tokens=100, generated_tokens=50 + user_data["userid"], prompt_tokens=100, completion_tokens=50 ) # Verify token usage token_usage = user_manager.get_user_token_usage(user_data["userid"]) assert token_usage is not None - assert token_usage["input_tokens"] == 100 - assert token_usage["generated_tokens"] == 50 + assert token_usage["prompt_tokens"] == 100 + assert token_usage["completion_tokens"] == 50 # Update again to check cumulative effect user_manager.update_token_usage( - user_data["userid"], input_tokens=50, generated_tokens=25 + user_data["userid"], prompt_tokens=50, completion_tokens=25 ) token_usage = user_manager.get_user_token_usage(user_data["userid"]) assert token_usage is not None - assert token_usage["input_tokens"] == 150 - assert token_usage["generated_tokens"] == 75 + assert token_usage["prompt_tokens"] == 150 + assert token_usage["completion_tokens"] == 75 def test_get_all_users(self, user_manager): """Test retrieving all users.""" diff --git a/tests/nilai_api/test_state.py b/tests/nilai_api/test_state.py new file mode 100644 index 00000000..f8f6c689 --- /dev/null +++ b/tests/nilai_api/test_state.py @@ -0,0 +1,84 @@ +import pytest +from unittest.mock import patch, AsyncMock +from nilai_api.state import AppState + + +@pytest.fixture +def app_state(): + return AppState() + + +def test_generate_key_pair(app_state): + assert app_state.private_key is not None + assert app_state.public_key is not None + assert app_state.verifying_key is not None + + +def 
test_semaphore_initialization(app_state): + assert app_state.sem._value == 2 + + +def test_uptime(app_state): + uptime = app_state.uptime + assert ( + "days" in uptime + or "hours" in uptime + or "minutes" in uptime + or "seconds" in uptime + ) + + +@patch("nilai_api.state.sev.init") +@patch("nilai_api.state.sev.get_quote", return_value="mocked_quote") +def test_cpu_attestation(mock_get_quote, mock_init, app_state): + assert app_state.cpu_attestation == "mocked_quote" + mock_init.assert_called_once() + mock_get_quote.assert_called_once() + + +def test_gpu_attestation(app_state): + assert app_state.gpu_attestation == "" + + +@pytest.mark.asyncio +async def test_models(app_state): + with patch.object( + app_state.discovery_service, "discover_models", new_callable=AsyncMock + ) as mock_discover_models: + mock_discover_models.return_value = {"model1": "endpoint1"} + models = await app_state.models + assert models == {"model1": "endpoint1"} + mock_discover_models.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_model(app_state): + with patch.object( + app_state.discovery_service, "get_model", new_callable=AsyncMock + ) as mock_get_model: + mock_get_model.return_value = "endpoint1" + model = await app_state.get_model("model1") + assert model == "endpoint1" + mock_get_model.assert_awaited_once_with("model1") + + +@pytest.mark.asyncio +async def test_models_empty(app_state): + with patch.object( + app_state.discovery_service, "discover_models", new_callable=AsyncMock + ) as mock_discover_models: + mock_discover_models.return_value = {} + models = await app_state.models + assert models == {} + mock_discover_models.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_model_not_found(app_state): + with patch.object( + app_state.discovery_service, "get_model", new_callable=AsyncMock + ) as mock_get_model: + mock_get_model.return_value = None + model = await app_state.get_model("non_existent_model") + assert model is None + mock_get_model.assert_awaited_once_with("non_existent_model") diff --git a/tests/nilai_models/models/test_llama_1b_cpu.py b/tests/nilai_models/models/test_llama_1b_cpu.py new file mode 100644 index 00000000..14cf554e --- /dev/null +++ b/tests/nilai_models/models/test_llama_1b_cpu.py @@ -0,0 +1,67 @@ +import pytest +from fastapi import HTTPException +from unittest.mock import AsyncMock +from nilai_models.models.llama_1b_cpu.llama_1b_cpu import Llama1BCpu +from nilai_common import ChatRequest, Message, ChatResponse + +from tests import response as RESPONSE + + +@pytest.fixture +def llama_model(mocker): + """Fixture to provide a Llama1BCpu instance for testing.""" + model = Llama1BCpu() + mocker.patch.object(model, "chat_completion", new_callable=AsyncMock) + return model + + +@pytest.mark.asyncio +async def test_chat_completion_valid_request(llama_model): + """Test chat completion with a valid request.""" + req = ChatRequest( + model="bartowski/Llama-3.2-1B-Instruct-GGUF", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is your name?"), + ], + ) + llama_model.chat_completion.return_value = RESPONSE + response = await llama_model.chat_completion(req) + assert isinstance(response, ChatResponse) + assert response.model == "test-model" + assert response.choices is not None + + +@pytest.mark.asyncio +async def test_chat_completion_empty_messages(llama_model): + """Test chat completion with an empty messages list.""" + req = ChatRequest( + model="bartowski/Llama-3.2-1B-Instruct-GGUF", + 
messages=[], + ) + llama_model.chat_completion.side_effect = HTTPException( + status_code=400, detail="The 'messages' field is required." + ) + with pytest.raises(HTTPException) as exc_info: + await llama_model.chat_completion(req) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "The 'messages' field is required." + + +@pytest.mark.asyncio +async def test_chat_completion_missing_model(llama_model): + """Test chat completion with a missing model field.""" + req = ChatRequest( + model="", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is your name?"), + ], + ) + llama_model.chat_completion.side_effect = HTTPException( + status_code=400, detail="The 'model' field is required." + ) + with pytest.raises(HTTPException) as exc_info: + await llama_model.chat_completion(req) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "The 'model' field is required." diff --git a/tests/nilai_models/test_model.py b/tests/nilai_models/test_model.py new file mode 100644 index 00000000..36ea9259 --- /dev/null +++ b/tests/nilai_models/test_model.py @@ -0,0 +1,58 @@ +from nilai_common.api_model import Message +import pytest +from fastapi.testclient import TestClient +from nilai_models.model import Model +from nilai_common import ChatRequest, ChatResponse +from tests import model_metadata, response + + +class MyModel(Model): + async def chat_completion(self, req: ChatRequest) -> ChatResponse: + return response + + +@pytest.fixture +def model_instance(): + metadata = model_metadata + return MyModel(metadata) + + +@pytest.fixture +def client(model_instance): + return TestClient(model_instance.get_app()) + + +def test_model_info(client): + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["id"] == "ABC" + assert data["name"] == "ABC" + assert data["description"] == "Description" + + +def test_health_check(client): + response = client.get("/v1/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "uptime" in data + + +def test_chat_completion(client): + request = ChatRequest( + model="ABC", + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello, who are you?"), + ], + ) + response = client.post("/v1/chat/completions", json=request.model_dump()) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + assert data["choices"][0]["finish_reason"] == "test-finish-reason" + assert data["usage"]["prompt_tokens"] == 100 + assert data["usage"]["completion_tokens"] == 50 + assert data["usage"]["total_tokens"] == 150 + assert data["signature"] == "test-signature" diff --git a/uv.lock b/uv.lock index ce3c5092..f3fd95e2 100644 --- a/uv.lock +++ b/uv.lock @@ -595,6 +595,8 @@ dev = [ { name = "black" }, { name = "isort" }, { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "ruff" }, { name = "uvicorn" }, ] @@ -611,6 +613,8 @@ dev = [ { name = "black", specifier = ">=24.10.0" }, { name = "isort", specifier = ">=5.13.2" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.25.0" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "ruff", specifier = ">=0.8.0" }, { name = "uvicorn", specifier = ">=0.32.1" }, ] @@ -992,6 +996,30 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/18/82fcb4ee47d66d99f6cd1efc0b11b2a25029f303c599a5afda7c1bca4254/pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609", size = 53298 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/56/2ee0cab25c11d4e38738a2a98c645a8f002e2ecf7b5ed774c70d53b92bb1/pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3", size = 19245 }, +] + +[[package]] +name = "pytest-mock" +version = "3.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/90/a955c3ab35ccd41ad4de556596fa86685bf4fc5ffcc62d22d856cfd4e29a/pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0", size = 32814 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/3b/b26f90f74e2986a82df6e7ac7e319b8ea7ccece1caec9f8ab6104dc70603/pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f", size = 9863 }, +] + [[package]] name = "python-dotenv" version = "1.0.1"