From 4072d925d0b6726a7efee46a079421f4d55bcd0c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?=
Date: Wed, 18 Dec 2024 13:01:53 +0000
Subject: [PATCH 1/3] feat: Added maximum supported token windows for each model

---
 .github/workflows/ci.yml | 8 --------
 nilai-api/src/nilai_api/routers/private.py | 4 ++--
 .../src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py | 2 +-
 .../src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py | 2 +-
 nilai-models/src/nilai_models/models/llama_model.py | 8 ++++++--
 .../models/secret_llama_1b_cpu/secret_llama_1b_cpu.py | 4 ++--
 packages/nilai-common/src/nilai_common/api_model.py | 2 +-
 7 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 13ca532e..18204519 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,14 +18,6 @@ jobs:
           enable-cache: true
           cache-dependency-glob: "**/pyproject.toml"

-      - name: Cache dependencies
-        uses: actions/cache@v3
-        with:
-          path: ${{ env.UV_CACHE_DIR }}
-          key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }}
-          restore-keys: |
-            ${{ runner.os }}-uv-
-
       - name: Install dependencies
         run: |
           uv sync
diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py
index 33932e10..1d8c740b 100644
--- a/nilai-api/src/nilai_api/routers/private.py
+++ b/nilai-api/src/nilai_api/routers/private.py
@@ -160,7 +160,7 @@ async def stream_response() -> AsyncGenerator[str, None]:
                 "POST",
                 f"{model_url}/v1/chat/completions",
                 json=req.model_dump(),
-                timeout=60.0,
+                timeout=None,
             ) as response:
                 response.raise_for_status()  # Raise an error for invalid status codes

@@ -188,7 +188,7 @@
     try:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{model_url}/v1/chat/completions", json=req.model_dump(), timeout=60.0
+                f"{model_url}/v1/chat/completions", json=req.model_dump(), timeout=None
             )
            response.raise_for_status()
            model_response = ChatResponse.model_validate_json(response.content)
diff --git a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
index d807360d..2ddca7ac 100644
--- a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
+++ b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
@@ -26,7 +26,7 @@ def __init__(self):
                 repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
                 filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
                 n_threads=16,
-                n_ctx=2048,
+                n_ctx=128 * 1024,
                 verbose=False,
             ),
             metadata=ModelMetadata(
diff --git a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
index 54d25802..b6b64314 100644
--- a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
+++ b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
@@ -37,7 +37,7 @@ def __init__(self):
                 repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
                 filename="Meta-Llama-3-8B-Instruct-Q5_K_M.gguf",
                 n_threads=16,
-                n_ctx=2048,
+                n_ctx=8 * 1024,
                 verbose=False,
             ),
             metadata=ModelMetadata(
diff --git a/nilai-models/src/nilai_models/models/llama_model.py b/nilai-models/src/nilai_models/models/llama_model.py
index cdaaf939..0b128d61 100644
--- a/nilai-models/src/nilai_models/models/llama_model.py
+++ b/nilai-models/src/nilai_models/models/llama_model.py
@@ -78,7 +78,7 @@ def generate() -> Generator[str, Any, None]:
                 prompt,  # type: ignore
                 stream=True,
                 temperature=req.temperature if req.temperature else 0.2,
-                max_tokens=req.max_tokens,
+                max_tokens=req.max_tokens if req.max_tokens else 2048,
             ):
                 # Extract delta content from output
                 choices = output.get("choices", [])  # type: ignore
@@ -103,7 +103,11 @@ def generate() -> Generator[str, Any, None]:

         # Non-streaming (regular) chat completion
         try:
-            generation: dict = self.model.create_chat_completion(prompt)  # type: ignore
+            generation: dict = self.model.create_chat_completion(
+                prompt,
+                temperature=req.temperature if req.temperature else 0.2,
+                max_tokens=req.max_tokens if req.max_tokens else 2048,
+            )  # type: ignore
         except ValueError:
             raise HTTPException(
                 status_code=400,
diff --git a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
index 871c6245..358b7070 100644
--- a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
+++ b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
@@ -27,7 +27,7 @@ def __init__(self):
                 repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
                 filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
                 n_threads=16,
-                n_ctx=2048,
+                n_ctx=128 * 1024,
                 verbose=False,
             ),
             metadata=ModelMetadata(
@@ -40,7 +40,7 @@ def __init__(self):
                 source="https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF",  # Model source
                 supported_features=["chat_completion"],  # Capabilities
             ),
-            prefix="d01fe399-8dc2-4c74-acde-ff649802f437",
+            prefix="nillion/",
         )

     async def chat_completion(
diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py
index c4687b46..ef374700 100644
--- a/packages/nilai-common/src/nilai_common/api_model.py
+++ b/packages/nilai-common/src/nilai_common/api_model.py
@@ -13,7 +13,7 @@ class ChatRequest(BaseModel):
     model: str
     messages: List[Message]
     temperature: Optional[float] = 0.2
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 2048
     stream: Optional[bool] = False

From 783f8a395b6b88193f5db64c019502baf294fb19 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?=
Date: Wed, 18 Dec 2024 14:08:14 +0000
Subject: [PATCH 2/3] feat: correction to test llama_1b

---
 .github/workflows/ci.yml | 8 ++++++++
 tests/nilai_models/models/test_llama_1b_cpu.py | 3 ++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 18204519..13ca532e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,6 +18,14 @@ jobs:
           enable-cache: true
           cache-dependency-glob: "**/pyproject.toml"

+      - name: Cache dependencies
+        uses: actions/cache@v3
+        with:
+          path: ${{ env.UV_CACHE_DIR }}
+          key: ${{ runner.os }}-uv-${{ hashFiles('**/pyproject.toml') }}
+          restore-keys: |
+            ${{ runner.os }}-uv-
+
       - name: Install dependencies
         run: |
           uv sync
diff --git a/tests/nilai_models/models/test_llama_1b_cpu.py b/tests/nilai_models/models/test_llama_1b_cpu.py
index d1a8a1d3..0731d073 100644
--- a/tests/nilai_models/models/test_llama_1b_cpu.py
+++ b/tests/nilai_models/models/test_llama_1b_cpu.py
@@ -7,11 +7,12 @@
 from tests import response as RESPONSE

+model = Llama1BCpu()
+

 @pytest.fixture
 def llama_model(mocker):
     """Fixture to provide a Llama1BCpu instance for testing."""
-    model = Llama1BCpu()
     mocker.patch.object(model, "chat_completion", new_callable=AsyncMock)
     return model

From 0a2fb770a82dcfcede2a8d1f096ab3e825e9bda3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?=
Date: Wed, 18 Dec 2024 15:24:18 +0000
Subject: [PATCH 3/3] Merge linting and formatting changes

---
 nilai-models/src/nilai_models/models/llama_model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/nilai-models/src/nilai_models/models/llama_model.py b/nilai-models/src/nilai_models/models/llama_model.py
index 93b04403..4f969e99 100644
--- a/nilai-models/src/nilai_models/models/llama_model.py
+++ b/nilai-models/src/nilai_models/models/llama_model.py
@@ -86,7 +86,6 @@ async def generate() -> AsyncGenerator[str, None]:
                 ),
             )
             for output in output_generator:
-                # Extract delta content from output
                 choices = output.get("choices", [])  # type: ignore
                 if not choices or "delta" not in choices[0]:
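
For illustration, a minimal sketch of the pattern patch 1 introduces: load the GGUF model with its maximum supported context window (n_ctx) and fall back to service-wide defaults when a request leaves temperature or max_tokens unset. It assumes llama-cpp-python's Llama.from_pretrained and create_chat_completion API as used in the patched files; the helper name complete_with_defaults is hypothetical and not part of the patches.

    from llama_cpp import Llama

    # Llama 3.2 1B supports a 128k-token context window, so n_ctx is raised
    # from 2048 to 128 * 1024 in the patch.
    llm = Llama.from_pretrained(
        repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
        filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
        n_ctx=128 * 1024,
        n_threads=16,
        verbose=False,
    )


    def complete_with_defaults(messages, temperature=None, max_tokens=None):
        # Mirror the `x if x else default` fallbacks added in llama_model.py:
        # unset values resolve to the service-wide defaults (0.2 / 2048).
        return llm.create_chat_completion(
            messages,
            temperature=temperature if temperature else 0.2,
            max_tokens=max_tokens if max_tokens else 2048,
        )


    # Example request, analogous to ChatRequest with max_tokens defaulting to 2048.
    print(complete_with_defaults([{"role": "user", "content": "Hello!"}]))

Note that Meta-Llama-3-8B's architecture supports an 8k-token context, which is why the 8B config is capped at 8 * 1024 rather than 128 * 1024 in the same patch.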