diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py
index d09e88b9..11aba0bb 100644
--- a/nilai-api/src/nilai_api/routers/private.py
+++ b/nilai-api/src/nilai_api/routers/private.py
@@ -161,7 +161,7 @@ async def stream_response() -> AsyncGenerator[str, None]:
                 "POST",
                 f"{model_url}/v1/chat/completions",
                 json=req.model_dump(),
-                timeout=60.0,
+                timeout=None,
             ) as response:
                 response.raise_for_status()  # Raise an error for invalid status codes
@@ -192,7 +192,7 @@ async def stream_response() -> AsyncGenerator[str, None]:
     try:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{model_url}/v1/chat/completions", json=req.model_dump(), timeout=60.0
+                f"{model_url}/v1/chat/completions", json=req.model_dump(), timeout=None
             )
             response.raise_for_status()
             model_response = ChatResponse.model_validate_json(response.content)
diff --git a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
index d807360d..2ddca7ac 100644
--- a/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
+++ b/nilai-models/src/nilai_models/models/llama_1b_cpu/llama_1b_cpu.py
@@ -26,7 +26,7 @@ def __init__(self):
                 repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
                 filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
                 n_threads=16,
-                n_ctx=2048,
+                n_ctx=128 * 1024,
                 verbose=False,
             ),
             metadata=ModelMetadata(
diff --git a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
index 54d25802..b6b64314 100644
--- a/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
+++ b/nilai-models/src/nilai_models/models/llama_8b_cpu/llama_8b_cpu.py
@@ -37,7 +37,7 @@ def __init__(self):
                 repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
                 filename="Meta-Llama-3-8B-Instruct-Q5_K_M.gguf",
                 n_threads=16,
-                n_ctx=2048,
+                n_ctx=8 * 1024,
                 verbose=False,
             ),
             metadata=ModelMetadata(
diff --git a/nilai-models/src/nilai_models/models/llama_model.py b/nilai-models/src/nilai_models/models/llama_model.py
index 418a74e4..4f969e99 100644
--- a/nilai-models/src/nilai_models/models/llama_model.py
+++ b/nilai-models/src/nilai_models/models/llama_model.py
@@ -82,7 +82,7 @@ async def generate() -> AsyncGenerator[str, None]:
                     prompt,  # type: ignore
                     stream=True,
                     temperature=req.temperature if req.temperature else 0.2,
-                    max_tokens=req.max_tokens,
+                    max_tokens=req.max_tokens if req.max_tokens else 2048,
                 ),
             )
             for output in output_generator:
@@ -110,7 +110,11 @@ async def generate() -> AsyncGenerator[str, None]:
 
         # Non-streaming (regular) chat completion
         try:
-            generation: dict = self.model.create_chat_completion(prompt)  # type: ignore
+            generation: dict = self.model.create_chat_completion(
+                prompt,
+                temperature=req.temperature if req.temperature else 0.2,
+                max_tokens=req.max_tokens if req.max_tokens else 2048,
+            )  # type: ignore
         except ValueError:
             raise HTTPException(
                 status_code=400,
diff --git a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
index 871c6245..358b7070 100644
--- a/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
+++ b/nilai-models/src/nilai_models/models/secret_llama_1b_cpu/secret_llama_1b_cpu.py
@@ -27,7 +27,7 @@ def __init__(self):
                 repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
                 filename="Llama-3.2-1B-Instruct-Q5_K_S.gguf",
                 n_threads=16,
-                n_ctx=2048,
+                n_ctx=128 * 1024,
                 verbose=False,
             ),
             metadata=ModelMetadata(
@@ -40,7 +40,7 @@ def __init__(self):
                 source="https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF",  # Model source
                 supported_features=["chat_completion"],  # Capabilities
             ),
-            prefix="d01fe399-8dc2-4c74-acde-ff649802f437",
+            prefix="nillion/",
         )
 
     async def chat_completion(
diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py
index c4687b46..ef374700 100644
--- a/packages/nilai-common/src/nilai_common/api_model.py
+++ b/packages/nilai-common/src/nilai_common/api_model.py
@@ -13,7 +13,7 @@ class ChatRequest(BaseModel):
     model: str
     messages: List[Message]
     temperature: Optional[float] = 0.2
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 2048
     stream: Optional[bool] = False
diff --git a/tests/nilai_models/models/test_llama_1b_cpu.py b/tests/nilai_models/models/test_llama_1b_cpu.py
index d1a8a1d3..0731d073 100644
--- a/tests/nilai_models/models/test_llama_1b_cpu.py
+++ b/tests/nilai_models/models/test_llama_1b_cpu.py
@@ -7,11 +7,12 @@
 from tests import response as RESPONSE
 
 
+model = Llama1BCpu()
+
 @pytest.fixture
 def llama_model(mocker):
     """Fixture to provide a Llama1BCpu instance for testing."""
-    model = Llama1BCpu()
     mocker.patch.object(model, "chat_completion", new_callable=AsyncMock)
     return model