From 22c1a3b570794878650ce8718bba437c061d61cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 09:09:45 +0000 Subject: [PATCH 01/12] feat: improved context windows --- .../src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py | 1 + .../src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py | 1 + .../models/secret_llama_1b_cpu/secret_llama_1b_cpu.py | 1 + 3 files changed, 3 insertions(+) diff --git a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py index e133a0b1..f11c8e6a 100644 --- a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py +++ b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py @@ -34,6 +34,7 @@ def __init__(self): repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF", filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf", n_threads=16, + n_ctx=2048, verbose=False, ) diff --git a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py index 57026944..bb9802bd 100644 --- a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py +++ b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py @@ -34,6 +34,7 @@ def __init__(self): repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF", filename="Meta-Llama-3-8B-Instruct-Q5_K_M.gguf", n_threads=16, + n_ctx=2048, verbose=False, ) diff --git a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py index 70f4d43e..2a1b4103 100644 --- a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py +++ b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py @@ -34,6 +34,7 @@ def __init__(self): repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF", filename="Llama-3.2-1B-Instruct-Q5_K_M.gguf", n_threads=16, + n_ctx=2048, verbose=False, ) From 364a829f431d2fab29cd3d89a6324ee69c06ef23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 11:13:21 +0000 Subject: [PATCH 02/12] feat: added nilai-api tests --- nilai-api/src/nilai_api/routers/private.py | 8 +- nilai-api/src/nilai_api/state.py | 4 +- .../secret_llama_1b_cpu.py | 10 +- pyproject.toml | 2 + tests/model_execution_0.py | 38 ---- tests/model_execution_1.py | 49 ----- tests/model_execution_2.py | 34 ---- tests/model_execution_3.py | 25 --- tests/nilai_api/__init__.py | 0 tests/nilai_api/routers/__init__.py | 0 tests/nilai_api/routers/test_private.py | 182 ++++++++++++++++++ tests/nilai_api/routers/test_public.py | 20 ++ tests/nilai_api/test_app.py | 10 + tests/nilai_api/test_auth.py | 38 ++++ tests/{ => nilai_api}/test_cryptography.py | 2 +- tests/{ => nilai_api}/test_db.py | 19 +- tests/nilai_api/test_state.py | 89 +++++++++ uv.lock | 28 +++ 18 files changed, 392 insertions(+), 166 deletions(-) delete mode 100644 tests/model_execution_0.py delete mode 100644 tests/model_execution_1.py delete mode 100644 tests/model_execution_2.py delete mode 100644 tests/model_execution_3.py create mode 100644 tests/nilai_api/__init__.py create mode 100644 tests/nilai_api/routers/__init__.py create mode 100644 tests/nilai_api/routers/test_private.py create mode 100644 tests/nilai_api/routers/test_public.py create mode 100644 tests/nilai_api/test_app.py create mode 100644 tests/nilai_api/test_auth.py rename tests/{ => nilai_api}/test_cryptography.py (96%) rename tests/{ => nilai_api}/test_db.py (90%) create mode 100644 tests/nilai_api/test_state.py diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 1923967c..0f6fb47c 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -30,7 +30,7 @@ async def get_usage(user: dict = Depends(get_user)) -> Usage: """ Retrieve the current token usage for the authenticated user. - - **user**: Authenticated user information (through X-API-Key header) + - **user**: Authenticated user information (through HTTP Bearer header) - **Returns**: Usage statistics for the user's token consumption ### Example @@ -47,7 +47,7 @@ async def get_attestation(user: dict = Depends(get_user)) -> AttestationResponse """ Generate a cryptographic attestation report. - - **user**: Authenticated user information (through X-API-Key header) + - **user**: Authenticated user information (through HTTP Bearer header) - **Returns**: Attestation details for service verification ### Attestation Details @@ -70,7 +70,7 @@ async def get_models(user: dict = Depends(get_user)) -> list[ModelMetadata]: """ List all available models in the system. - - **user**: Authenticated user information (through X-API-Key header) + - **user**: Authenticated user information (through HTTP Bearer header) - **Returns**: Dictionary of available models ### Example @@ -100,7 +100,7 @@ async def chat_completion( Generate a chat completion response from the AI model. - **req**: Chat completion request containing messages and model specifications - - **user**: Authenticated user information (through X-API-Key header) + - **user**: Authenticated user information (through HTTP Bearer header) - **Returns**: Full chat response with model output, usage statistics, and cryptographic signature ### Request Requirements diff --git a/nilai-api/src/nilai_api/state.py b/nilai-api/src/nilai_api/state.py index d023aa84..5795cf10 100644 --- a/nilai-api/src/nilai_api/state.py +++ b/nilai-api/src/nilai_api/state.py @@ -22,7 +22,7 @@ def __init__(self): ) self._uptime = time.time() self._cpu_quote = None - self._gpu_quote = None + self._gpu_quote = "" @property def cpu_attestation(self) -> str: @@ -36,7 +36,7 @@ def cpu_attestation(self) -> str: @property def gpu_attestation(self) -> str: - return "" + return self._gpu_quote @property def uptime(self): diff --git a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py index 2a1b4103..8084caca 100644 --- a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py +++ b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py @@ -95,10 +95,12 @@ async def chat_completion( for msg in req.messages ] - prompt += [{ - "role": "system", - "content": "In addition to the previous. You are a cheese expert. You use cheese for all your answers. Whatever the user asks, you respond with a cheese-related answer or analogy.", - }] + prompt += [ + { + "role": "system", + "content": "In addition to the previous. You are a cheese expert. You use cheese for all your answers. Whatever the user asks, you respond with a cheese-related answer or analogy.", + } + ] # Generate chat completion using the Llama model # - Converts messages into a model-compatible prompt # - type: ignore suppresses type checking for external library diff --git a/pyproject.toml b/pyproject.toml index dd835824..f49adc47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,11 @@ dependencies = [ dev = [ "black>=24.10.0", "isort>=5.13.2", + "pytest-mock>=3.14.0", "pytest>=8.3.3", "ruff>=0.8.0", "uvicorn>=0.32.1", + "pytest-asyncio>=0.25.0", ] [build-system] diff --git a/tests/model_execution_0.py b/tests/model_execution_0.py deleted file mode 100644 index f7bbe1ae..00000000 --- a/tests/model_execution_0.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import time - -import torch -from dotenv import load_dotenv -from transformers import pipeline - -# Load the .env file -load_dotenv() - -# # Application State Initialization -torch.set_num_threads(32) -torch.set_num_interop_threads(32) - - -chat_pipeline = pipeline( - "text-generation", - model="meta-llama/Llama-3.2-1B-Instruct", - model_kwargs={"torch_dtype": torch.bfloat16}, - device_map="cpu", - token=os.getenv("HUGGINGFACE_API_TOKEN"), -) - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -start = time.time() -# Generate response -generated = chat_pipeline( - messages, max_length=1024, num_return_sequences=1, truncation=True -) # type: ignore - -end = time.time() - -print(generated) -print(end - start) diff --git a/tests/model_execution_1.py b/tests/model_execution_1.py deleted file mode 100644 index 2ef17ab4..00000000 --- a/tests/model_execution_1.py +++ /dev/null @@ -1,49 +0,0 @@ -import time - -from onnxruntime import InferenceSession -from optimum.onnxruntime import ORTModelForCausalLM -from transformers import AutoTokenizer - -# Define the model directory and ONNX export location -model_name = "meta-llama/Llama-3.2-1B-Instruct" -onnx_export_dir = "./onnx_model" - -# Export the model -model = ORTModelForCausalLM.from_pretrained(model_name, from_transformers=True) -model.save_pretrained(onnx_export_dir) - -# Save the tokenizer for later use -tokenizer = AutoTokenizer.from_pretrained(model_name) -tokenizer.save_pretrained(onnx_export_dir) - - -# Load the ONNX model and tokenizer -onnx_model_path = "./onnx_model/model.onnx" -tokenizer = AutoTokenizer.from_pretrained("./onnx_model") - -# Create an ONNX Runtime session -session = InferenceSession(onnx_model_path) - -# Input messages -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -# Prepare input text -input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - -# Tokenize input text -inputs = tokenizer(input_text, return_tensors="pt") -print("START:") -# Run inference -start = time.time() -onnx_inputs = {session.get_inputs()[0].name: inputs["input_ids"].numpy()} -onnx_output = session.run(None, onnx_inputs) -end = time.time() - -# Decode the output -output_text = tokenizer.decode(onnx_output[0][0], skip_special_tokens=True) - -print(output_text) -print(f"Time taken: {end - start} seconds") diff --git a/tests/model_execution_2.py b/tests/model_execution_2.py deleted file mode 100644 index 821279ae..00000000 --- a/tests/model_execution_2.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import time - -from dotenv import load_dotenv -from optimum.pipelines import pipeline - -# Load the .env file -load_dotenv() - - -chat_pipeline = pipeline( - "text-generation", - model="meta-llama/Llama-3.2-1B-Instruct", - accelerator="ort", - token=os.getenv("HUGGINGFACE_API_TOKEN"), -) - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -print("start") -for i in range(10): - start = time.time() - # Generate response - generated = chat_pipeline( - messages, max_length=1024, num_return_sequences=1, truncation=True - ) # type: ignore - - end = time.time() - - print(generated) - print(end - start) diff --git a/tests/model_execution_3.py b/tests/model_execution_3.py deleted file mode 100644 index 33c4ace7..00000000 --- a/tests/model_execution_3.py +++ /dev/null @@ -1,25 +0,0 @@ -import time - -from llama_cpp import Llama - -llm = Llama.from_pretrained( - repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF", - filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf", -) - - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, -] - -print("start") -for i in range(10): - start = time.time() - # Generate response - generated = llm.create_chat_completion(messages) - - end = time.time() - - print(generated) - print(end - start) diff --git a/tests/nilai_api/__init__.py b/tests/nilai_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nilai_api/routers/__init__.py b/tests/nilai_api/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nilai_api/routers/test_private.py b/tests/nilai_api/routers/test_private.py new file mode 100644 index 00000000..cfe850ef --- /dev/null +++ b/tests/nilai_api/routers/test_private.py @@ -0,0 +1,182 @@ +from unittest.mock import AsyncMock +from nilai_common.api_model import ( + ChatResponse, + Choice, + Message, + ModelEndpoint, + ModelMetadata, + Usage, +) +import pytest +import asyncio +from fastapi.testclient import TestClient +from nilai_api.app import app +from nilai_api.db import UserManager +from nilai_api.state import state + +client = TestClient(app) +model_metadata = ModelMetadata( + id="ABC", # Unique identifier + name="ABC", # Human-readable name + version="1.0", # Model version + description="Description", + author="Author", # Model creators + license="License", # Usage license + source="http://test-model-url", # Model source + supported_features=["supported_feature"], # Capabilities +) + +model_endpoint = ModelEndpoint(url="http://test-model-url", metadata=model_metadata) + + +@pytest.mark.asyncio +async def test_runs_in_a_loop(): + assert asyncio.get_running_loop() + + +@pytest.fixture +def mock_user(): + return {"userid": "test-user-id", "name": "Test User"} + + +@pytest.fixture +def mock_user_manager(mocker): + mocker.patch.object( + UserManager, + "get_token_usage", + return_value={ + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "queries": 10, + }, + ) + mocker.patch.object(UserManager, "update_token_usage") + mocker.patch.object( + UserManager, + "get_user_token_usage", + return_value={"prompt_tokens": 100, "completion_tokens": 50}, + ) + mocker.patch.object( + UserManager, + "insert_user", + return_value={"userid": "test-user-id", "apikey": "test-api-key"}, + ) + mocker.patch.object( + UserManager, + "check_api_key", + return_value={"name": "Test User", "userid": "test-user-id"}, + ) + mocker.patch.object( + UserManager, + "get_all_users", + return_value=[ + {"userid": "test-user-id", "apikey": "test-api-key"}, + {"userid": "test-user-id-2", "apikey": "test-api-key"}, + ], + ) + + +@pytest.fixture +def mock_state(mocker, event_loop): + # Prepare expected models data + + expected_models = {"ABC": model_endpoint} + + # Create a mock discovery service that returns the expected models + mock_discovery_service = mocker.Mock() + mock_discovery_service.discover_models = AsyncMock(return_value=expected_models) + + # Create a mock AppState + mocker.patch.object(state, "discovery_service", mock_discovery_service) + + # Patch other attributes + mocker.patch.object(state, "verifying_key", "test-verifying-key") + mocker.patch.object(state, "_cpu_quote", "test-cpu-attestation") + mocker.patch.object(state, "_gpu_quote", "test-gpu-attestation") + + # Patch get_model method + mocker.patch.object(state, "get_model", return_value=model_endpoint) + + return state + + +# Example test +@pytest.mark.asyncio +async def test_models_property(mock_state): + # Retrieve the models + models = await state.models + + # Assert the expected models + assert models == {"ABC": model_endpoint} + + +def test_get_usage(mock_user, mock_user_manager): + response = client.get("/v1/usage", headers={"Authorization": "Bearer test-api-key"}) + assert response.status_code == 200 + assert response.json() == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } + + +def test_get_attestation(mock_user, mock_user_manager, mock_state): + response = client.get( + "/v1/attestation/report", headers={"Authorization": "Bearer test-api-key"} + ) + assert response.status_code == 200 + assert response.json() == { + "verifying_key": "test-verifying-key", + "cpu_attestation": "test-cpu-attestation", + "gpu_attestation": "test-gpu-attestation", + } + + +def test_get_models(mock_user, mock_user_manager, mock_state): + response = client.get( + "/v1/models", headers={"Authorization": "Bearer test-api-key"} + ) + assert response.status_code == 200 + assert response.json() == [model_metadata.model_dump()] + + +def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker): + response = ChatResponse( + id="test-id", + object="test-object", + model="test-model", + created=123456, + choices=[ + Choice( + index=0, + message=Message(role="test-role", content="test-content"), + finish_reason="test-finish-reason", + logprobs={"test-logprobs": "test-value"}, + ) + ], + usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150), + signature="test-signature", + ) + mocker.patch( + "httpx.AsyncClient.post", + return_value=mocker.Mock(status_code=200, content=response.model_dump_json()), + ) + response = client.post( + "/v1/chat/completions", + json={ + "model": "Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is your name?"}, + ], + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 200 + assert "usage" in response.json() + assert response.json()["usage"] == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } diff --git a/tests/nilai_api/routers/test_public.py b/tests/nilai_api/routers/test_public.py new file mode 100644 index 00000000..8f308f9b --- /dev/null +++ b/tests/nilai_api/routers/test_public.py @@ -0,0 +1,20 @@ +import pytest +from fastapi.testclient import TestClient +from nilai_api.routers.public import router +from nilai_api.state import state +from nilai_common import HealthCheckResponse +from fastapi import FastAPI + +app = FastAPI() +app.include_router(router) + +client = TestClient(app) + +def test_health_check(): + """Test the health check endpoint.""" + response = client.get("/v1/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "uptime" in data + assert isinstance(data["uptime"], str) \ No newline at end of file diff --git a/tests/nilai_api/test_app.py b/tests/nilai_api/test_app.py new file mode 100644 index 00000000..1b2e8622 --- /dev/null +++ b/tests/nilai_api/test_app.py @@ -0,0 +1,10 @@ +from fastapi.testclient import TestClient +from nilai_api.app import app + +client = TestClient(app) + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert "openapi" in response.json() diff --git a/tests/nilai_api/test_auth.py b/tests/nilai_api/test_auth.py new file mode 100644 index 00000000..b550fcbb --- /dev/null +++ b/tests/nilai_api/test_auth.py @@ -0,0 +1,38 @@ +import pytest +from fastapi import HTTPException +from fastapi.security import HTTPAuthorizationCredentials +from nilai_api.auth import get_user +from nilai_api.db import UserManager + + +@pytest.fixture +def mock_user_manager(mocker): + """Fixture to mock UserManager methods.""" + mocker.patch.object(UserManager, "check_api_key") + return UserManager + + +def test_get_user_valid_token(mock_user_manager): + """Test get_user with a valid token.""" + mock_user_manager.check_api_key.return_value = { + "name": "Test User", + "userid": "test-user-id", + } + credentials = HTTPAuthorizationCredentials( + scheme="Bearer", credentials="valid-token" + ) + user = get_user(credentials) + assert user["name"] == "Test User" + assert user["userid"] == "test-user-id" + + +def test_get_user_invalid_token(mock_user_manager): + """Test get_user with an invalid token.""" + mock_user_manager.check_api_key.return_value = None + credentials = HTTPAuthorizationCredentials( + scheme="Bearer", credentials="invalid-token" + ) + with pytest.raises(HTTPException) as exc_info: + get_user(credentials) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Missing or invalid API key" diff --git a/tests/test_cryptography.py b/tests/nilai_api/test_cryptography.py similarity index 96% rename from tests/test_cryptography.py rename to tests/nilai_api/test_cryptography.py index 702ca540..aa75dd6f 100644 --- a/tests/test_cryptography.py +++ b/tests/nilai_api/test_cryptography.py @@ -4,7 +4,7 @@ from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import ec -from nilai.crypto import generate_key_pair, sign_message, verify_signature +from nilai_api.crypto import generate_key_pair, sign_message, verify_signature def test_generate_key_pair(): diff --git a/tests/test_db.py b/tests/nilai_api/test_db.py similarity index 90% rename from tests/test_db.py rename to tests/nilai_api/test_db.py index ae713125..e9b1d923 100644 --- a/tests/test_db.py +++ b/tests/nilai_api/test_db.py @@ -79,8 +79,8 @@ def test_insert_user(self, user_manager): # Verify user can be retrieved retrieved_user_tokens = user_manager.get_user_token_usage(user_data["userid"]) assert retrieved_user_tokens is not None - assert retrieved_user_tokens["input_tokens"] == 0 - assert retrieved_user_tokens["generated_tokens"] == 0 + assert retrieved_user_tokens["prompt_tokens"] == 0 + assert retrieved_user_tokens["completion_tokens"] == 0 def test_check_api_key(self, user_manager): """Test API key validation.""" @@ -89,7 +89,8 @@ def test_check_api_key(self, user_manager): # Check valid API key user_name = user_manager.check_api_key(user_data["apikey"]) - assert user_name == "Check API User" + assert user_name["name"] == "Check API User" + assert user_name["userid"] == user_data["userid"] # Check invalid API key invalid_result = user_manager.check_api_key("invalid-api-key") @@ -102,24 +103,24 @@ def test_update_token_usage(self, user_manager): # Update token usage user_manager.update_token_usage( - user_data["userid"], input_tokens=100, generated_tokens=50 + user_data["userid"], prompt_tokens=100, completion_tokens=50 ) # Verify token usage token_usage = user_manager.get_user_token_usage(user_data["userid"]) assert token_usage is not None - assert token_usage["input_tokens"] == 100 - assert token_usage["generated_tokens"] == 50 + assert token_usage["prompt_tokens"] == 100 + assert token_usage["completion_tokens"] == 50 # Update again to check cumulative effect user_manager.update_token_usage( - user_data["userid"], input_tokens=50, generated_tokens=25 + user_data["userid"], prompt_tokens=50, completion_tokens=25 ) token_usage = user_manager.get_user_token_usage(user_data["userid"]) assert token_usage is not None - assert token_usage["input_tokens"] == 150 - assert token_usage["generated_tokens"] == 75 + assert token_usage["prompt_tokens"] == 150 + assert token_usage["completion_tokens"] == 75 def test_get_all_users(self, user_manager): """Test retrieving all users.""" diff --git a/tests/nilai_api/test_state.py b/tests/nilai_api/test_state.py new file mode 100644 index 00000000..1325e79f --- /dev/null +++ b/tests/nilai_api/test_state.py @@ -0,0 +1,89 @@ +import pytest +from unittest.mock import patch, AsyncMock +from nilai_api.state import AppState + + +@pytest.fixture +def app_state(): + return AppState() + + +def test_generate_key_pair(app_state): + assert app_state.private_key is not None + assert app_state.public_key is not None + assert app_state.verifying_key is not None + + +def test_semaphore_initialization(app_state): + assert app_state.sem._value == 2 + + +def test_uptime(app_state): + uptime = app_state.uptime + assert ( + "days" in uptime + or "hours" in uptime + or "minutes" in uptime + or "seconds" in uptime + ) + + +@patch("nilai_api.state.init") +@patch("nilai_api.state.get_quote", return_value="mocked_quote") +def test_cpu_attestation(mock_get_quote, mock_init, app_state): + assert app_state.cpu_attestation == "mocked_quote" + mock_init.assert_called_once() + mock_get_quote.assert_called_once() + + +def test_cpu_attestation_non_tee(app_state): + with patch("nilai_api.state.init", side_effect=RuntimeError): + assert app_state.cpu_attestation == "" + + +def test_gpu_attestation(app_state): + assert app_state.gpu_attestation == "" + + +@pytest.mark.asyncio +async def test_models(app_state): + with patch.object( + app_state.discovery_service, "discover_models", new_callable=AsyncMock + ) as mock_discover_models: + mock_discover_models.return_value = {"model1": "endpoint1"} + models = await app_state.models + assert models == {"model1": "endpoint1"} + mock_discover_models.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_model(app_state): + with patch.object( + app_state.discovery_service, "get_model", new_callable=AsyncMock + ) as mock_get_model: + mock_get_model.return_value = "endpoint1" + model = await app_state.get_model("model1") + assert model == "endpoint1" + mock_get_model.assert_awaited_once_with("model1") + + +@pytest.mark.asyncio +async def test_models_empty(app_state): + with patch.object( + app_state.discovery_service, "discover_models", new_callable=AsyncMock + ) as mock_discover_models: + mock_discover_models.return_value = {} + models = await app_state.models + assert models == {} + mock_discover_models.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_model_not_found(app_state): + with patch.object( + app_state.discovery_service, "get_model", new_callable=AsyncMock + ) as mock_get_model: + mock_get_model.return_value = None + model = await app_state.get_model("non_existent_model") + assert model is None + mock_get_model.assert_awaited_once_with("non_existent_model") diff --git a/uv.lock b/uv.lock index ce3c5092..f3fd95e2 100644 --- a/uv.lock +++ b/uv.lock @@ -595,6 +595,8 @@ dev = [ { name = "black" }, { name = "isort" }, { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "ruff" }, { name = "uvicorn" }, ] @@ -611,6 +613,8 @@ dev = [ { name = "black", specifier = ">=24.10.0" }, { name = "isort", specifier = ">=5.13.2" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.25.0" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "ruff", specifier = ">=0.8.0" }, { name = "uvicorn", specifier = ">=0.32.1" }, ] @@ -992,6 +996,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/18/82fcb4ee47d66d99f6cd1efc0b11b2a25029f303c599a5afda7c1bca4254/pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609", size = 53298 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/56/2ee0cab25c11d4e38738a2a98c645a8f002e2ecf7b5ed774c70d53b92bb1/pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3", size = 19245 }, +] + +[[package]] +name = "pytest-mock" +version = "3.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/90/a955c3ab35ccd41ad4de556596fa86685bf4fc5ffcc62d22d856cfd4e29a/pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0", size = 32814 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/3b/b26f90f74e2986a82df6e7ac7e319b8ea7ccece1caec9f8ab6104dc70603/pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f", size = 9863 }, +] + [[package]] name = "python-dotenv" version = "1.0.1" From 12aaf543b9f685c9f8a131838c1b8a636e994b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 11:24:20 +0000 Subject: [PATCH 03/12] feat: added nilai-common tests --- .../nilai-common/src/nilai_common/__init__.py | 2 +- .../src/nilai_common/{db.py => discovery.py} | 0 tests/nilai-common/__init__.py | 0 tests/nilai-common/test_api_model.py | 51 ++++++++++++ tests/nilai-common/test_discovery.py | 82 +++++++++++++++++++ 5 files changed, 134 insertions(+), 1 deletion(-) rename packages/nilai-common/src/nilai_common/{db.py => discovery.py} (100%) create mode 100644 tests/nilai-common/__init__.py create mode 100644 tests/nilai-common/test_api_model.py create mode 100644 tests/nilai-common/test_discovery.py diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 4cf49984..b07568c1 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -9,7 +9,7 @@ Usage, ) from nilai_common.config import SETTINGS -from nilai_common.db import ModelServiceDiscovery +from nilai_common.discovery import ModelServiceDiscovery __all__ = [ "Message", diff --git a/packages/nilai-common/src/nilai_common/db.py b/packages/nilai-common/src/nilai_common/discovery.py similarity index 100% rename from packages/nilai-common/src/nilai_common/db.py rename to packages/nilai-common/src/nilai_common/discovery.py diff --git a/tests/nilai-common/__init__.py b/tests/nilai-common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nilai-common/test_api_model.py b/tests/nilai-common/test_api_model.py new file mode 100644 index 00000000..fff3a9b5 --- /dev/null +++ b/tests/nilai-common/test_api_model.py @@ -0,0 +1,51 @@ +import pytest +from pydantic import ValidationError +from nilai_common.api_model import ModelMetadata + +def test_model_metadata_creation(): + """Test creating a ModelMetadata instance.""" + metadata = ModelMetadata( + name="Test Model", + version="1.0", + description="A test model", + author="Test Author", + license="MIT", + source="https://example.com", + supported_features=["feature1", "feature2"] + ) + + assert metadata.id is not None + assert metadata.name == "Test Model" + assert metadata.version == "1.0" + assert metadata.description == "A test model" + assert metadata.author == "Test Author" + assert metadata.license == "MIT" + assert metadata.source == "https://example.com" + assert metadata.supported_features == ["feature1", "feature2"] + +def test_model_metadata_default_id(): + """Test that ModelMetadata generates a default UUID for id.""" + metadata = ModelMetadata( + name="Test Model", + version="1.0", + description="A test model", + author="Test Author", + license="MIT", + source="https://example.com", + supported_features=["feature1", "feature2"] + ) + + assert metadata.id is not None + assert len(metadata.id) == 36 # UUID length + +def test_model_metadata_invalid_data(): + """Test creating ModelMetadata with invalid data.""" + with pytest.raises(ValidationError): + ModelMetadata( + name="", + version="", + description="", + author="", + license="", + source="", + ) # type: ignore \ No newline at end of file diff --git a/tests/nilai-common/test_discovery.py b/tests/nilai-common/test_discovery.py new file mode 100644 index 00000000..0e36bb03 --- /dev/null +++ b/tests/nilai-common/test_discovery.py @@ -0,0 +1,82 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from nilai_common.discovery import ModelServiceDiscovery +from nilai_common.api_model import ModelEndpoint, ModelMetadata + +@pytest.fixture +def model_service_discovery(): + with patch('nilai_common.discovery.Etcd3Client') as MockClient: + mock_client = MockClient.return_value + discovery = ModelServiceDiscovery() + discovery.client = mock_client + yield discovery + +@pytest.fixture +def model_endpoint(): + model_metadata = ModelMetadata( + name="Test Model", + version="1.0.0", + description="Test model description", + author="Test Author", + license="MIT", + source="https://github.com/test/model", + supported_features=["test_feature"], + ) + return ModelEndpoint( + url="http://test-model-service.example.com/predict", metadata=model_metadata + ) + +@pytest.mark.asyncio +async def test_register_model(model_service_discovery, model_endpoint): + lease_mock = MagicMock() + model_service_discovery.client.lease.return_value = lease_mock + + lease = await model_service_discovery.register_model(model_endpoint) + + model_service_discovery.client.put.assert_called_once_with( + f"/models/{model_endpoint.metadata.id}", + model_endpoint.model_dump_json(), + lease=lease_mock + ) + assert lease == lease_mock + +@pytest.mark.asyncio +async def test_discover_models(model_service_discovery, model_endpoint): + model_service_discovery.client.get_prefix.return_value = [ + (model_endpoint.model_dump_json().encode('utf-8'), None) + ] + + discovered_models = await model_service_discovery.discover_models() + + assert len(discovered_models) == 1 + assert model_endpoint.metadata.id in discovered_models + assert discovered_models[model_endpoint.metadata.id] == model_endpoint + +@pytest.mark.asyncio +async def test_get_model(model_service_discovery, model_endpoint): + model_service_discovery.client.get.return_value = (model_endpoint.model_dump_json().encode('utf-8'), None) + + model = await model_service_discovery.get_model(model_endpoint.metadata.id) + + assert model == model_endpoint + +@pytest.mark.asyncio +async def test_unregister_model(model_service_discovery, model_endpoint): + await model_service_discovery.unregister_model(model_endpoint.metadata.id) + + model_service_discovery.client.delete.assert_called_once_with(f"/models/{model_endpoint.metadata.id}") + +@pytest.mark.asyncio +async def test_keep_alive(model_service_discovery): + lease_mock = MagicMock() + lease_mock.refresh = AsyncMock() + + async def keep_alive_task(): + await model_service_discovery.keep_alive(lease_mock) + + task = asyncio.create_task(keep_alive_task()) + await asyncio.sleep(0.1) + task.cancel() + + lease_mock.refresh.assert_called() \ No newline at end of file From a38f3a10038cf57b7688d104e1f45766e414f46f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 11:58:06 +0000 Subject: [PATCH 04/12] feat: added nilai_models tests --- .../nilai-common/src/nilai_common/__init__.py | 2 + tests/__init__.py | 31 ++++++++++ tests/nilai_api/routers/test_private.py | 14 +---- .../nilai_models/models/test_llama_1b_cpu.py | 60 +++++++++++++++++++ tests/nilai_models/test_model.py | 52 ++++++++++++++++ 5 files changed, 147 insertions(+), 12 deletions(-) create mode 100644 tests/nilai_models/models/test_llama_1b_cpu.py create mode 100644 tests/nilai_models/test_model.py diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index b07568c1..44e39f63 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -2,6 +2,7 @@ AttestationResponse, ChatRequest, ChatResponse, + Choice, HealthCheckResponse, Message, ModelEndpoint, @@ -15,6 +16,7 @@ "Message", "ChatRequest", "ChatResponse", + "Choice", "ModelMetadata", "Usage", "AttestationResponse", diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..ad0ef970 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,31 @@ +from nilai_common import ChatResponse, Choice, Message, ModelEndpoint, ModelMetadata, Usage + +model_metadata: ModelMetadata = ModelMetadata( + id="ABC", # Unique identifier + name="ABC", # Human-readable name + version="1.0", # Model version + description="Description", + author="Author", # Model creators + license="License", # Usage license + source="http://test-model-url", # Model source + supported_features=["supported_feature"], # Capabilities +) + +model_endpoint: ModelEndpoint = ModelEndpoint(url="http://test-model-url", metadata=model_metadata) + +response: ChatResponse = ChatResponse( + id="test-id", + object="test-object", + model="test-model", + created=123456, + choices=[ + Choice( + index=0, + message=Message(role="test-role", content="test-content"), + finish_reason="test-finish-reason", + logprobs={"test-logprobs": "test-value"}, + ) + ], + usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150), + signature="test-signature", +) \ No newline at end of file diff --git a/tests/nilai_api/routers/test_private.py b/tests/nilai_api/routers/test_private.py index cfe850ef..fea013ea 100644 --- a/tests/nilai_api/routers/test_private.py +++ b/tests/nilai_api/routers/test_private.py @@ -14,19 +14,9 @@ from nilai_api.db import UserManager from nilai_api.state import state -client = TestClient(app) -model_metadata = ModelMetadata( - id="ABC", # Unique identifier - name="ABC", # Human-readable name - version="1.0", # Model version - description="Description", - author="Author", # Model creators - license="License", # Usage license - source="http://test-model-url", # Model source - supported_features=["supported_feature"], # Capabilities -) +from tests import model_metadata, model_endpoint -model_endpoint = ModelEndpoint(url="http://test-model-url", metadata=model_metadata) +client = TestClient(app) @pytest.mark.asyncio diff --git a/tests/nilai_models/models/test_llama_1b_cpu.py b/tests/nilai_models/models/test_llama_1b_cpu.py new file mode 100644 index 00000000..e1d526f4 --- /dev/null +++ b/tests/nilai_models/models/test_llama_1b_cpu.py @@ -0,0 +1,60 @@ +import pytest +from fastapi import HTTPException +from unittest.mock import AsyncMock +from nilai_models.models.llama_1b_cpu.llama_1b_cpu import Llama1BCpu +from nilai_common import ChatRequest, Message, ChatResponse + +from tests import model_metadata, model_endpoint, response as RESPONSE + + +@pytest.fixture +def llama_model(mocker): + """Fixture to provide a Llama1BCpu instance for testing.""" + model = Llama1BCpu() + mocker.patch.object(model, 'chat_completion', new_callable=AsyncMock) + return model + +@pytest.mark.asyncio +async def test_chat_completion_valid_request(llama_model): + """Test chat completion with a valid request.""" + req = ChatRequest( + model="bartowski/Llama-3.2-1B-Instruct-GGUF", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is your name?"), + ], + ) + llama_model.chat_completion.return_value = RESPONSE + response = await llama_model.chat_completion(req) + assert isinstance(response, ChatResponse) + assert response.model == "test-model" + assert response.choices is not None + +@pytest.mark.asyncio +async def test_chat_completion_empty_messages(llama_model): + """Test chat completion with an empty messages list.""" + req = ChatRequest( + model="bartowski/Llama-3.2-1B-Instruct-GGUF", + messages=[], + ) + llama_model.chat_completion.side_effect = HTTPException(status_code=400, detail="The 'messages' field is required.") + with pytest.raises(HTTPException) as exc_info: + await llama_model.chat_completion(req) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "The 'messages' field is required." + +@pytest.mark.asyncio +async def test_chat_completion_missing_model(llama_model): + """Test chat completion with a missing model field.""" + req = ChatRequest( + model="", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is your name?"), + ], + ) + llama_model.chat_completion.side_effect = HTTPException(status_code=400, detail="The 'model' field is required.") + with pytest.raises(HTTPException) as exc_info: + await llama_model.chat_completion(req) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "The 'model' field is required." diff --git a/tests/nilai_models/test_model.py b/tests/nilai_models/test_model.py new file mode 100644 index 00000000..4a69142b --- /dev/null +++ b/tests/nilai_models/test_model.py @@ -0,0 +1,52 @@ +from nilai_common.api_model import Choice, Message, Usage +import pytest +from fastapi.testclient import TestClient +from nilai_models.model import Model +from nilai_common import ModelMetadata, ChatRequest, ChatResponse, HealthCheckResponse, ModelEndpoint +from tests import model_metadata, response, model_endpoint + +class TestModel(Model): + async def chat_completion(self, req: ChatRequest) -> ChatResponse: + return response + +@pytest.fixture +def model_instance(): + metadata = model_metadata + return TestModel(metadata) + +@pytest.fixture +def client(model_instance): + return TestClient(model_instance.get_app()) + +def test_model_info(client): + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["id"] == "ABC" + assert data["name"] == "ABC" + assert data["description"] == "Description" + +def test_health_check(client): + response = client.get("/v1/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "uptime" in data + +def test_chat_completion(client): + request = ChatRequest( + model="ABC", + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello, who are you?") + ] + ) + response = client.post("/v1/chat/completions", json=request.model_dump()) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + assert data["choices"][0]["finish_reason"] == "test-finish-reason" + assert data["usage"]["prompt_tokens"] == 100 + assert data["usage"]["completion_tokens"] == 50 + assert data["usage"]["total_tokens"] == 150 + assert data["signature"] == "test-signature" From b131af302e04d0abf60fd2149fcd01055a3f0673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 15:07:24 +0000 Subject: [PATCH 05/12] fix: corrections to test_model --- tests/nilai_models/test_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nilai_models/test_model.py b/tests/nilai_models/test_model.py index 4a69142b..1aceabcc 100644 --- a/tests/nilai_models/test_model.py +++ b/tests/nilai_models/test_model.py @@ -5,14 +5,14 @@ from nilai_common import ModelMetadata, ChatRequest, ChatResponse, HealthCheckResponse, ModelEndpoint from tests import model_metadata, response, model_endpoint -class TestModel(Model): +class MyModel(Model): async def chat_completion(self, req: ChatRequest) -> ChatResponse: return response @pytest.fixture def model_instance(): metadata = model_metadata - return TestModel(metadata) + return MyModel(metadata) @pytest.fixture def client(model_instance): From 3196844feb543a2ff180a75232ababd3bac2e703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 15:15:37 +0000 Subject: [PATCH 06/12] fix: various fixes to linting and formatting --- tests/__init__.py | 15 +++++++++--- tests/nilai-common/test_api_model.py | 9 ++++--- tests/nilai-common/test_discovery.py | 24 ++++++++++++++----- tests/nilai_api/routers/test_private.py | 2 -- tests/nilai_api/routers/test_public.py | 18 +++++++------- .../nilai_models/models/test_llama_1b_cpu.py | 15 ++++++++---- tests/nilai_models/test_model.py | 16 +++++++++---- 7 files changed, 66 insertions(+), 33 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index ad0ef970..694d9aba 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,11 @@ -from nilai_common import ChatResponse, Choice, Message, ModelEndpoint, ModelMetadata, Usage +from nilai_common import ( + ChatResponse, + Choice, + Message, + ModelEndpoint, + ModelMetadata, + Usage, +) model_metadata: ModelMetadata = ModelMetadata( id="ABC", # Unique identifier @@ -11,7 +18,9 @@ supported_features=["supported_feature"], # Capabilities ) -model_endpoint: ModelEndpoint = ModelEndpoint(url="http://test-model-url", metadata=model_metadata) +model_endpoint: ModelEndpoint = ModelEndpoint( + url="http://test-model-url", metadata=model_metadata +) response: ChatResponse = ChatResponse( id="test-id", @@ -28,4 +37,4 @@ ], usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150), signature="test-signature", -) \ No newline at end of file +) diff --git a/tests/nilai-common/test_api_model.py b/tests/nilai-common/test_api_model.py index fff3a9b5..ac59db6e 100644 --- a/tests/nilai-common/test_api_model.py +++ b/tests/nilai-common/test_api_model.py @@ -2,6 +2,7 @@ from pydantic import ValidationError from nilai_common.api_model import ModelMetadata + def test_model_metadata_creation(): """Test creating a ModelMetadata instance.""" metadata = ModelMetadata( @@ -11,7 +12,7 @@ def test_model_metadata_creation(): author="Test Author", license="MIT", source="https://example.com", - supported_features=["feature1", "feature2"] + supported_features=["feature1", "feature2"], ) assert metadata.id is not None @@ -23,6 +24,7 @@ def test_model_metadata_creation(): assert metadata.source == "https://example.com" assert metadata.supported_features == ["feature1", "feature2"] + def test_model_metadata_default_id(): """Test that ModelMetadata generates a default UUID for id.""" metadata = ModelMetadata( @@ -32,12 +34,13 @@ def test_model_metadata_default_id(): author="Test Author", license="MIT", source="https://example.com", - supported_features=["feature1", "feature2"] + supported_features=["feature1", "feature2"], ) assert metadata.id is not None assert len(metadata.id) == 36 # UUID length + def test_model_metadata_invalid_data(): """Test creating ModelMetadata with invalid data.""" with pytest.raises(ValidationError): @@ -48,4 +51,4 @@ def test_model_metadata_invalid_data(): author="", license="", source="", - ) # type: ignore \ No newline at end of file + ) # type: ignore diff --git a/tests/nilai-common/test_discovery.py b/tests/nilai-common/test_discovery.py index 0e36bb03..d1714f5f 100644 --- a/tests/nilai-common/test_discovery.py +++ b/tests/nilai-common/test_discovery.py @@ -4,14 +4,16 @@ from nilai_common.discovery import ModelServiceDiscovery from nilai_common.api_model import ModelEndpoint, ModelMetadata + @pytest.fixture def model_service_discovery(): - with patch('nilai_common.discovery.Etcd3Client') as MockClient: + with patch("nilai_common.discovery.Etcd3Client") as MockClient: mock_client = MockClient.return_value discovery = ModelServiceDiscovery() discovery.client = mock_client yield discovery + @pytest.fixture def model_endpoint(): model_metadata = ModelMetadata( @@ -27,6 +29,7 @@ def model_endpoint(): url="http://test-model-service.example.com/predict", metadata=model_metadata ) + @pytest.mark.asyncio async def test_register_model(model_service_discovery, model_endpoint): lease_mock = MagicMock() @@ -37,14 +40,15 @@ async def test_register_model(model_service_discovery, model_endpoint): model_service_discovery.client.put.assert_called_once_with( f"/models/{model_endpoint.metadata.id}", model_endpoint.model_dump_json(), - lease=lease_mock + lease=lease_mock, ) assert lease == lease_mock + @pytest.mark.asyncio async def test_discover_models(model_service_discovery, model_endpoint): model_service_discovery.client.get_prefix.return_value = [ - (model_endpoint.model_dump_json().encode('utf-8'), None) + (model_endpoint.model_dump_json().encode("utf-8"), None) ] discovered_models = await model_service_discovery.discover_models() @@ -53,19 +57,27 @@ async def test_discover_models(model_service_discovery, model_endpoint): assert model_endpoint.metadata.id in discovered_models assert discovered_models[model_endpoint.metadata.id] == model_endpoint + @pytest.mark.asyncio async def test_get_model(model_service_discovery, model_endpoint): - model_service_discovery.client.get.return_value = (model_endpoint.model_dump_json().encode('utf-8'), None) + model_service_discovery.client.get.return_value = ( + model_endpoint.model_dump_json().encode("utf-8"), + None, + ) model = await model_service_discovery.get_model(model_endpoint.metadata.id) assert model == model_endpoint + @pytest.mark.asyncio async def test_unregister_model(model_service_discovery, model_endpoint): await model_service_discovery.unregister_model(model_endpoint.metadata.id) - model_service_discovery.client.delete.assert_called_once_with(f"/models/{model_endpoint.metadata.id}") + model_service_discovery.client.delete.assert_called_once_with( + f"/models/{model_endpoint.metadata.id}" + ) + @pytest.mark.asyncio async def test_keep_alive(model_service_discovery): @@ -79,4 +91,4 @@ async def keep_alive_task(): await asyncio.sleep(0.1) task.cancel() - lease_mock.refresh.assert_called() \ No newline at end of file + lease_mock.refresh.assert_called() diff --git a/tests/nilai_api/routers/test_private.py b/tests/nilai_api/routers/test_private.py index fea013ea..7adc31f9 100644 --- a/tests/nilai_api/routers/test_private.py +++ b/tests/nilai_api/routers/test_private.py @@ -3,8 +3,6 @@ ChatResponse, Choice, Message, - ModelEndpoint, - ModelMetadata, Usage, ) import pytest diff --git a/tests/nilai_api/routers/test_public.py b/tests/nilai_api/routers/test_public.py index 8f308f9b..3cb9c8c4 100644 --- a/tests/nilai_api/routers/test_public.py +++ b/tests/nilai_api/routers/test_public.py @@ -1,8 +1,5 @@ -import pytest from fastapi.testclient import TestClient from nilai_api.routers.public import router -from nilai_api.state import state -from nilai_common import HealthCheckResponse from fastapi import FastAPI app = FastAPI() @@ -10,11 +7,12 @@ client = TestClient(app) + def test_health_check(): - """Test the health check endpoint.""" - response = client.get("/v1/health") - assert response.status_code == 200 - data = response.json() - assert data["status"] == "ok" - assert "uptime" in data - assert isinstance(data["uptime"], str) \ No newline at end of file + """Test the health check endpoint.""" + response = client.get("/v1/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "uptime" in data + assert isinstance(data["uptime"], str) diff --git a/tests/nilai_models/models/test_llama_1b_cpu.py b/tests/nilai_models/models/test_llama_1b_cpu.py index e1d526f4..14cf554e 100644 --- a/tests/nilai_models/models/test_llama_1b_cpu.py +++ b/tests/nilai_models/models/test_llama_1b_cpu.py @@ -4,16 +4,17 @@ from nilai_models.models.llama_1b_cpu.llama_1b_cpu import Llama1BCpu from nilai_common import ChatRequest, Message, ChatResponse -from tests import model_metadata, model_endpoint, response as RESPONSE +from tests import response as RESPONSE @pytest.fixture def llama_model(mocker): """Fixture to provide a Llama1BCpu instance for testing.""" model = Llama1BCpu() - mocker.patch.object(model, 'chat_completion', new_callable=AsyncMock) + mocker.patch.object(model, "chat_completion", new_callable=AsyncMock) return model + @pytest.mark.asyncio async def test_chat_completion_valid_request(llama_model): """Test chat completion with a valid request.""" @@ -30,6 +31,7 @@ async def test_chat_completion_valid_request(llama_model): assert response.model == "test-model" assert response.choices is not None + @pytest.mark.asyncio async def test_chat_completion_empty_messages(llama_model): """Test chat completion with an empty messages list.""" @@ -37,12 +39,15 @@ async def test_chat_completion_empty_messages(llama_model): model="bartowski/Llama-3.2-1B-Instruct-GGUF", messages=[], ) - llama_model.chat_completion.side_effect = HTTPException(status_code=400, detail="The 'messages' field is required.") + llama_model.chat_completion.side_effect = HTTPException( + status_code=400, detail="The 'messages' field is required." + ) with pytest.raises(HTTPException) as exc_info: await llama_model.chat_completion(req) assert exc_info.value.status_code == 400 assert exc_info.value.detail == "The 'messages' field is required." + @pytest.mark.asyncio async def test_chat_completion_missing_model(llama_model): """Test chat completion with a missing model field.""" @@ -53,7 +58,9 @@ async def test_chat_completion_missing_model(llama_model): Message(role="user", content="What is your name?"), ], ) - llama_model.chat_completion.side_effect = HTTPException(status_code=400, detail="The 'model' field is required.") + llama_model.chat_completion.side_effect = HTTPException( + status_code=400, detail="The 'model' field is required." + ) with pytest.raises(HTTPException) as exc_info: await llama_model.chat_completion(req) assert exc_info.value.status_code == 400 diff --git a/tests/nilai_models/test_model.py b/tests/nilai_models/test_model.py index 1aceabcc..36ea9259 100644 --- a/tests/nilai_models/test_model.py +++ b/tests/nilai_models/test_model.py @@ -1,23 +1,27 @@ -from nilai_common.api_model import Choice, Message, Usage +from nilai_common.api_model import Message import pytest from fastapi.testclient import TestClient from nilai_models.model import Model -from nilai_common import ModelMetadata, ChatRequest, ChatResponse, HealthCheckResponse, ModelEndpoint -from tests import model_metadata, response, model_endpoint +from nilai_common import ChatRequest, ChatResponse +from tests import model_metadata, response + class MyModel(Model): async def chat_completion(self, req: ChatRequest) -> ChatResponse: return response + @pytest.fixture def model_instance(): metadata = model_metadata return MyModel(metadata) + @pytest.fixture def client(model_instance): return TestClient(model_instance.get_app()) + def test_model_info(client): response = client.get("/v1/models") assert response.status_code == 200 @@ -26,6 +30,7 @@ def test_model_info(client): assert data["name"] == "ABC" assert data["description"] == "Description" + def test_health_check(client): response = client.get("/v1/health") assert response.status_code == 200 @@ -33,13 +38,14 @@ def test_health_check(client): assert data["status"] == "healthy" assert "uptime" in data + def test_chat_completion(client): request = ChatRequest( model="ABC", messages=[ Message(role="system", content="You are a helpful assistant"), - Message(role="user", content="Hello, who are you?") - ] + Message(role="user", content="Hello, who are you?"), + ], ) response = client.post("/v1/chat/completions", json=request.model_dump()) assert response.status_code == 200 From 7610eb46527a444d080cc1edd0eeabd79db251c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 15:16:20 +0000 Subject: [PATCH 07/12] feat: added github actions workflow --- .github/workflows/ci.yml | 43 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..266ea3d7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,43 @@ +name: Python Tests + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main"] # Adjust branches as needed + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' # Adjust based on your requirements + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Cache dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/uv + key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-uv- + + - name: Install dependencies + run: | + uv sync + + - name: Run Ruff format check + run: uv run ruff format --check + + - name: Run Ruff linting + run: uv run ruff check + + - name: Run tests + run: uv run pytest -v \ No newline at end of file From 4ffbbce4b48025022d3f03a19be1f2a4f4bbeb62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 15:20:27 +0000 Subject: [PATCH 08/12] feat: added updated dependencies cache --- .github/workflows/ci.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 266ea3d7..5fbdf0a8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,14 @@ jobs: - name: Install dependencies run: | uv sync - + + - name: Update dependencies cache + if: always() # Ensures cache is updated even if the install fails + uses: actions/cache@v3 + with: + path: ~/.cache/uv + key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }} + - name: Run Ruff format check run: uv run ruff format --check From bae9afaf44aba4f43198b0b3c6ea61d7ed1b6001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 15:33:41 +0000 Subject: [PATCH 09/12] feat: added sev tests --- nilai-api/src/nilai_api/sev/sev.py | 198 +++++++++++++---------------- nilai-api/src/nilai_api/state.py | 6 +- tests/nilai_api/sev/__init__.py | 0 tests/nilai_api/sev/test_sev.py | 63 +++++++++ tests/nilai_api/test_state.py | 9 +- 5 files changed, 159 insertions(+), 117 deletions(-) create mode 100644 tests/nilai_api/sev/__init__.py create mode 100644 tests/nilai_api/sev/test_sev.py diff --git a/nilai-api/src/nilai_api/sev/sev.py b/nilai-api/src/nilai_api/sev/sev.py index d688bc54..c9090b00 100644 --- a/nilai-api/src/nilai_api/sev/sev.py +++ b/nilai-api/src/nilai_api/sev/sev.py @@ -1,116 +1,100 @@ import base64 import ctypes import os -from ctypes import c_char_p, c_int, create_string_buffer +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class SEVGuest: + def __init__(self): + self.lib: Optional[ctypes.CDLL] = None + self._load_library() + + def _load_library(self) -> None: + try: + lib_path = f"{os.path.dirname(os.path.abspath(__file__))}/libsevguest.so" + if not os.path.exists(lib_path): + logger.warning(f"SEV library not found at {lib_path}") + return + + self.lib = ctypes.CDLL(lib_path) + self._setup_library_functions() + except Exception as e: + logger.warning(f"Failed to load SEV library: {e}") + self.lib = None + + def _setup_library_functions(self) -> None: + if not self.lib: + return + + self.lib.OpenDevice.restype = ctypes.c_int + self.lib.GetQuoteProvider.restype = ctypes.c_int + self.lib.Init.restype = ctypes.c_int + self.lib.GetQuote.restype = ctypes.c_char_p + self.lib.GetQuote.argtypes = [ctypes.c_char_p] + self.lib.VerifyQuote.restype = ctypes.c_int + self.lib.VerifyQuote.argtypes = [ctypes.c_char_p] + self.lib.free.argtypes = [ctypes.c_char_p] + + def init(self) -> bool: + """Initialize the device and quote provider.""" + if not self.lib: + logger.warning("SEV library not loaded, running in mock mode") + return True + return self.lib.Init() == 0 + + def get_quote(self, report_data: Optional[bytes] = None) -> str: + """Get a quote using the report data.""" + if not self.lib: + logger.warning("SEV library not loaded, returning mock quote") + return base64.b64encode(b"mock_quote").decode("ascii") + + if report_data is None: + report_data = bytes(64) + + if len(report_data) != 64: + raise ValueError("Report data must be exactly 64 bytes") + + report_buffer = ctypes.create_string_buffer(report_data) + quote_ptr = self.lib.GetQuote(report_buffer) + + if quote_ptr is None: + raise RuntimeError("Failed to get quote") + + quote_str = ctypes.string_at(quote_ptr) + return base64.b64encode(quote_str).decode("ascii") + + def verify_quote(self, quote: str) -> bool: + """Verify the quote using the library's verification method.""" + if not self.lib: + logger.warning( + "SEV library not loaded, mock verification always returns True" + ) + return True + + quote_bytes = base64.b64decode(quote.encode("ascii")) + quote_buffer = ctypes.create_string_buffer(quote_bytes) + return self.lib.VerifyQuote(quote_buffer) == 0 + + +# Global instance +sev = SEVGuest() -# Load the shared library -lib = ctypes.CDLL(f"{os.path.dirname(os.path.abspath(__file__))}/libsevguest.so") - -# OpenDevice -lib.OpenDevice.restype = c_int - -# GetQuoteProvider -lib.GetQuoteProvider.restype = c_int - -# Init -lib.Init.restype = c_int - -# GetQuote -lib.GetQuote.restype = c_char_p -lib.GetQuote.argtypes = [c_char_p] - -# VerifyQuote -lib.VerifyQuote.restype = c_int -lib.VerifyQuote.argtypes = [c_char_p] - -lib.free.argtypes = [c_char_p] - - -# Python wrapper functions -def init(): - """Initialize the device and quote provider.""" - if lib.Init() != 0: - raise RuntimeError("Failed to initialize SEV guest device and quote provider.") - - -def get_quote(report_data=None) -> str: - """ - Get a quote using the report data. - - Args: - report_data (bytes, optional): 64-byte report data. - Defaults to 64 zero bytes. - - Returns: - str: The quote as a string - """ - # Use 64 zero bytes if no report data provided - if report_data is None: - report_data = bytes(64) - - # Validate report data - if len(report_data) != 64: - raise ValueError("Report data must be exactly 64 bytes") - - # Create a buffer from the report data - report_buffer = create_string_buffer(report_data) - - # Get the quote - quote_ptr = lib.GetQuote(report_buffer) - quote_str = ctypes.string_at(quote_ptr) - - # We should be freeing the quote, but it turns out it raises an error. - # lib.free(quote_ptr) - # Check if quote retrieval failed - if quote_ptr is None: - raise RuntimeError("Failed to get quote") - - # Convert quote to Python string - quote = base64.b64encode(quote_str) - return quote.decode("ascii") - - -def verify_quote(quote: str) -> bool: - """ - Verify the quote using the library's verification method. - - Args: - quote (str): The quote to verify - - Returns: - bool: True if quote is verified, False otherwise - """ - # Ensure quote is a string - if not isinstance(quote, str): - quote = str(quote) - - # Convert to bytes - quote_bytes = base64.b64decode(quote.encode("ascii")) - quote_buffer = create_string_buffer(quote_bytes) - - # Verify quote - result = lib.VerifyQuote(quote_buffer) - return result == 0 - - -# Example usage if __name__ == "__main__": try: - # Initialize the device and quote provider - init() - print("SEV guest device initialized successfully.") - - # Create a 64-byte report data array (all zeros for simplicity) - report_data = bytes([0] * 64) - - # Get the quote - quote = get_quote(report_data) - print(type(quote)) - print("Quote:", quote) - - if verify_quote(quote): - print("Quote verified successfully.") + if sev.init(): + print("SEV guest device initialized successfully.") + report_data = bytes([0] * 64) + quote = sev.get_quote(report_data) + print("Quote:", quote) + + if sev.verify_quote(quote): + print("Quote verified successfully.") + else: + print("Quote verification failed.") else: - print("Quote verification failed.") + print("Failed to initialize SEV guest device.") except Exception as e: print("Error:", e) diff --git a/nilai-api/src/nilai_api/state.py b/nilai-api/src/nilai_api/state.py index 5795cf10..e1680a92 100644 --- a/nilai-api/src/nilai_api/state.py +++ b/nilai-api/src/nilai_api/state.py @@ -5,7 +5,7 @@ from dotenv import load_dotenv from nilai_api.crypto import generate_key_pair -from nilai_api.sev.sev import get_quote, init +from nilai_api.sev.sev import sev from nilai_common import ModelServiceDiscovery, SETTINGS from nilai_common.api_model import ModelEndpoint @@ -28,8 +28,8 @@ def __init__(self): def cpu_attestation(self) -> str: if self._cpu_quote is None: try: - init() - self._cpu_quote = get_quote() + sev.init() + self._cpu_quote = sev.get_quote() except RuntimeError: self._cpu_quote = "" return self._cpu_quote diff --git a/tests/nilai_api/sev/__init__.py b/tests/nilai_api/sev/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nilai_api/sev/test_sev.py b/tests/nilai_api/sev/test_sev.py new file mode 100644 index 00000000..4ece3dfa --- /dev/null +++ b/tests/nilai_api/sev/test_sev.py @@ -0,0 +1,63 @@ +import base64 +import ctypes +import pytest +from nilai_api.sev.sev import SEVGuest + + +@pytest.fixture +def sev_guest(): + return SEVGuest() + + +def test_init_success(sev_guest, mocker): + mocker.patch.object(sev_guest, "_load_library", return_value=None) + sev_guest.lib = mocker.Mock() + sev_guest.lib.Init.return_value = 0 + assert sev_guest.init() is True + + +def test_init_failure(sev_guest, mocker): + mocker.patch.object(sev_guest, "_load_library", return_value=None) + sev_guest.lib = mocker.Mock() + sev_guest.lib.Init.return_value = -1 + assert sev_guest.init() is False + + +def test_get_quote_success(sev_guest, mocker): + mocker.patch.object(sev_guest, "_load_library", return_value=None) + sev_guest.lib = mocker.Mock() + sev_guest.lib.GetQuote.return_value = ctypes.create_string_buffer(b"quote_data") + report_data = bytes([0] * 64) + quote = sev_guest.get_quote(report_data) + expected_quote = base64.b64encode(b"quote_data").decode("ascii") + assert quote == expected_quote + + +def test_get_quote_failure(sev_guest, mocker): + mocker.patch.object(sev_guest, "_load_library", return_value=None) + sev_guest.lib = mocker.Mock() + sev_guest.lib.GetQuote.return_value = None + report_data = bytes([0] * 64) + with pytest.raises(RuntimeError): + sev_guest.get_quote(report_data) + + +def test_get_quote_invalid_report_data(sev_guest): + with pytest.raises(ValueError): + sev_guest.get_quote(bytes([0] * 63)) + + +def test_verify_quote_success(sev_guest, mocker): + mocker.patch.object(sev_guest, "_load_library", return_value=None) + sev_guest.lib = mocker.Mock() + sev_guest.lib.VerifyQuote.return_value = 0 + quote = base64.b64encode(b"quote_data").decode("ascii") + assert sev_guest.verify_quote(quote) is True + + +def test_verify_quote_failure(sev_guest, mocker): + mocker.patch.object(sev_guest, "_load_library", return_value=None) + sev_guest.lib = mocker.Mock() + sev_guest.lib.VerifyQuote.return_value = -1 + quote = base64.b64encode(b"quote_data").decode("ascii") + assert sev_guest.verify_quote(quote) is False diff --git a/tests/nilai_api/test_state.py b/tests/nilai_api/test_state.py index 1325e79f..f8f6c689 100644 --- a/tests/nilai_api/test_state.py +++ b/tests/nilai_api/test_state.py @@ -28,19 +28,14 @@ def test_uptime(app_state): ) -@patch("nilai_api.state.init") -@patch("nilai_api.state.get_quote", return_value="mocked_quote") +@patch("nilai_api.state.sev.init") +@patch("nilai_api.state.sev.get_quote", return_value="mocked_quote") def test_cpu_attestation(mock_get_quote, mock_init, app_state): assert app_state.cpu_attestation == "mocked_quote" mock_init.assert_called_once() mock_get_quote.assert_called_once() -def test_cpu_attestation_non_tee(app_state): - with patch("nilai_api.state.init", side_effect=RuntimeError): - assert app_state.cpu_attestation == "" - - def test_gpu_attestation(app_state): assert app_state.gpu_attestation == "" From 9445e24508fce7b7efe0d33310ec0b4eed49929e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 16 Dec 2024 15:42:24 +0000 Subject: [PATCH 10/12] fix: more fixes --- .github/workflows/ci.yml | 24 +++--------------------- tests/nilai_api/sev/test_sev.py | 5 +++-- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5fbdf0a8..9d3ed0d3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,32 +13,14 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: astral-sh/setup-uv@v4 with: - python-version: '3.11' # Adjust based on your requirements - - - name: Install uv - uses: astral-sh/setup-uv@v4 - - - name: Cache dependencies - uses: actions/cache@v3 - with: - path: ~/.cache/uv - key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-uv- + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" - name: Install dependencies run: | uv sync - - - name: Update dependencies cache - if: always() # Ensures cache is updated even if the install fails - uses: actions/cache@v3 - with: - path: ~/.cache/uv - key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }} - name: Run Ruff format check run: uv run ruff format --check diff --git a/tests/nilai_api/sev/test_sev.py b/tests/nilai_api/sev/test_sev.py index 4ece3dfa..545559bd 100644 --- a/tests/nilai_api/sev/test_sev.py +++ b/tests/nilai_api/sev/test_sev.py @@ -43,8 +43,9 @@ def test_get_quote_failure(sev_guest, mocker): def test_get_quote_invalid_report_data(sev_guest): - with pytest.raises(ValueError): - sev_guest.get_quote(bytes([0] * 63)) + if sev_guest.lib is not None: + with pytest.raises(ValueError): + sev_guest.get_quote(bytes([0] * 63)) def test_verify_quote_success(sev_guest, mocker): From b5a5d9807984006b2d7220282cbb489b572452d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 17 Dec 2024 08:48:57 +0000 Subject: [PATCH 11/12] fix: added cache --- .github/workflows/ci.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d3ed0d3..b4721c41 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,6 +18,14 @@ jobs: enable-cache: true cache-dependency-glob: "**/pyproject.toml" + - name: Cache dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/uv + key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-uv- + - name: Install dependencies run: | uv sync From c83d5ed914f833a2ae922c9f43a7aa639e33908c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 17 Dec 2024 08:57:52 +0000 Subject: [PATCH 12/12] feat: modified uv cache dir --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4721c41..1ddbc6a3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,11 +17,11 @@ jobs: with: enable-cache: true cache-dependency-glob: "**/pyproject.toml" - + - name: Cache dependencies uses: actions/cache@v3 with: - path: ~/.cache/uv + path: ${{ env.UV_CACHE_DIR }} key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }} restore-keys: | ${{ runner.os }}-uv-