From 4da07bd360a87419c598c9480bf61c0334cad572 Mon Sep 17 00:00:00 2001
From: Carles Onielfa
Date: Thu, 5 Dec 2024 11:51:06 +0100
Subject: [PATCH] Add NUA REMi endpoint to Nuclia SDK

---
 nuclia/lib/nua.py                     | 23 +++++++++++++
 nuclia/sdk/predict.py                 | 23 +++++++++++++
 nuclia/tests/test_nua/test_predict.py | 47 +++++++++++++++++++++++++++
 3 files changed, 93 insertions(+)

diff --git a/nuclia/lib/nua.py b/nuclia/lib/nua.py
index ceaa5e4..7083110 100644
--- a/nuclia/lib/nua.py
+++ b/nuclia/lib/nua.py
@@ -49,6 +49,7 @@
     TextGenerativeResponse,
     Tokens,
 )
+from nuclia_models.predict.remi import RemiRequest, RemiResponse
 
 SENTENCE_PREDICT = "/api/v1/predict/sentence"
 CHAT_PREDICT = "/api/v1/predict/chat"
@@ -56,6 +57,7 @@
 REPHRASE_PREDICT = "/api/v1/predict/rephrase"
 TOKENS_PREDICT = "/api/v1/predict/tokens"
 QUERY_PREDICT = "/api/v1/predict/query"
+REMI_PREDICT = "/api/v1/predict/remi"
 UPLOAD_PROCESS = "/api/v1/processing/upload"
 STATUS_PROCESS = "/api/v2/processing/status"
 PUSH_PROCESS = "/api/v2/processing/push"
@@ -286,6 +288,18 @@ def rephrase(
             output=RephraseModel,
         )
 
+    def remi(
+        self,
+        request: RemiRequest,
+    ) -> RemiResponse:
+        endpoint = f"{self.url}{REMI_PREDICT}"
+        return self._request(
+            "POST",
+            endpoint,
+            payload=request.model_dump(),
+            output=RemiResponse,
+        )
+
     def process_file(self, path: str, kbid: str = "default") -> PushResponseV2:
         filename = path.split("/")[-1]
         upload_endpoint = f"{self.url}{UPLOAD_PROCESS}"
@@ -594,6 +608,15 @@ async def rephrase(
             output=RephraseModel,
         )
 
+    async def remi(self, request: RemiRequest) -> RemiResponse:
+        endpoint = f"{self.url}{REMI_PREDICT}"
+        return await self._request(
+            "POST",
+            endpoint,
+            payload=request.model_dump(),
+            output=RemiResponse,
+        )
+
     async def generate_retrieval(
         self,
         question: str,
diff --git a/nuclia/sdk/predict.py b/nuclia/sdk/predict.py
index 4097314..e12503a 100644
--- a/nuclia/sdk/predict.py
+++ b/nuclia/sdk/predict.py
@@ -18,6 +18,7 @@
     UserPrompt,
 )
 from nuclia.sdk.auth import NucliaAuth
+from nuclia_models.predict.remi import RemiRequest, RemiResponse
 
 
 class NucliaPredict:
@@ -142,6 +143,17 @@ def rag(
 
         return nc.generate(body, model)
 
+    @nua
+    def remi(self, request: RemiRequest, **kwargs) -> RemiResponse:
+        """
+        Perform a REMi evaluation over a RAG experience
+
+        :param request: RemiRequest
+        :return: RemiResponse
+        """
+        nc: NuaClient = kwargs["nc"]
+        return nc.remi(request)
+
 
 class AsyncNucliaPredict:
     @property
@@ -257,3 +269,14 @@ async def rag(
     ) -> ChatResponse:
         nc: AsyncNuaClient = kwargs["nc"]
         return await nc.generate_retrieval(question, context, model)
+
+    @nua
+    async def remi(self, request: RemiRequest, **kwargs) -> RemiResponse:
+        """
+        Perform a REMi evaluation over a RAG experience
+
+        :param request: RemiRequest
+        :return: RemiResponse
+        """
+        nc: AsyncNuaClient = kwargs["nc"]
+        return await nc.remi(request)
diff --git a/nuclia/tests/test_nua/test_predict.py b/nuclia/tests/test_nua/test_predict.py
index 1bc3bb6..7f624c2 100644
--- a/nuclia/tests/test_nua/test_predict.py
+++ b/nuclia/tests/test_nua/test_predict.py
@@ -1,5 +1,7 @@
 from nuclia.lib.nua_responses import ChatModel, TextGenerativeResponse, UserPrompt
 from nuclia.sdk.predict import AsyncNucliaPredict, NucliaPredict
+import pytest
+from nuclia_models.predict.remi import RemiRequest
 
 
 def test_predict(testing_config):
@@ -119,3 +121,48 @@ async def test_nua_parse(testing_config):
         )
     )
     assert "SPORTS" in results.object["document_type"]
+
+
+def test_nua_remi(testing_config):
+    np = NucliaPredict()
+    results = np.remi(
+        RemiRequest(
+            user_id="Nuclia PY CLI",
+            question="What is the capital of France?",
+            answer="Paris is the capital of france!",
+            contexts=[
+                "Paris is the capital of France.",
+                "Berlin is the capital of Germany.",
+            ],
+        )
+    )
+    assert results.answer_relevance.score >= 4
+
+    assert results.context_relevance[0] >= 4
+    assert results.groundedness[0] >= 4
+
+    assert results.context_relevance[1] < 2
+    assert results.groundedness[1] < 2
+
+
+@pytest.mark.asyncio
+async def test_nua_async_remi(testing_config):
+    np = AsyncNucliaPredict()
+    results = await np.remi(
+        RemiRequest(
+            user_id="Nuclia PY CLI",
+            question="What is the capital of France?",
+            answer="Paris is the capital of france!",
+            contexts=[
+                "Paris is the capital of France.",
+                "Berlin is the capital of Germany.",
+            ],
+        )
+    )
+    assert results.answer_relevance.score >= 4
+
+    assert results.context_relevance[0] >= 4
+    assert results.groundedness[0] >= 4
+
+    assert results.context_relevance[1] < 2
+    assert results.groundedness[1] < 2