Skip to content

Commit

Permalink
Add NUA REMi endpoint to Nucloa SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
carlesonielfa committed Dec 5, 2024
1 parent 5be431e commit 4da07bd
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
23 changes: 23 additions & 0 deletions nuclia/lib/nua.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@
TextGenerativeResponse,
Tokens,
)
from nuclia_models.predict.remi import RemiRequest, RemiResponse

SENTENCE_PREDICT = "/api/v1/predict/sentence"
CHAT_PREDICT = "/api/v1/predict/chat"
SUMMARIZE_PREDICT = "/api/v1/predict/summarize"
REPHRASE_PREDICT = "/api/v1/predict/rephrase"
TOKENS_PREDICT = "/api/v1/predict/tokens"
QUERY_PREDICT = "/api/v1/predict/query"
REMI_PREDICT = "/api/v1/predict/remi"
UPLOAD_PROCESS = "/api/v1/processing/upload"
STATUS_PROCESS = "/api/v2/processing/status"
PUSH_PROCESS = "/api/v2/processing/push"
Expand Down Expand Up @@ -286,6 +288,18 @@ def rephrase(
output=RephraseModel,
)

def remi(
self,
request: RemiRequest,
) -> RemiResponse:
endpoint = f"{self.url}{REMI_PREDICT}"
return self._request(
"POST",
endpoint,
payload=request.model_dump(),
output=RemiResponse,
)

def process_file(self, path: str, kbid: str = "default") -> PushResponseV2:
filename = path.split("/")[-1]
upload_endpoint = f"{self.url}{UPLOAD_PROCESS}"
Expand Down Expand Up @@ -594,6 +608,15 @@ async def rephrase(
output=RephraseModel,
)

async def remi(self, request: RemiRequest) -> RemiResponse:
endpoint = f"{self.url}{REMI_PREDICT}"
return await self._request(
"POST",
endpoint,
payload=request.model_dump(),
output=RemiResponse,
)

async def generate_retrieval(
self,
question: str,
Expand Down
23 changes: 23 additions & 0 deletions nuclia/sdk/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
UserPrompt,
)
from nuclia.sdk.auth import NucliaAuth
from nuclia_models.predict.remi import RemiRequest, RemiResponse


class NucliaPredict:
Expand Down Expand Up @@ -142,6 +143,17 @@ def rag(

return nc.generate(body, model)

@nua
def remi(self, request: RemiRequest, **kwargs) -> RemiResponse:
"""
Perform a REMi evaluation over a RAG experience
:param request: RemiRequest
:return: RemiResponse
"""
nc: NuaClient = kwargs["nc"]
return nc.remi(request)


class AsyncNucliaPredict:
@property
Expand Down Expand Up @@ -257,3 +269,14 @@ async def rag(
) -> ChatResponse:
nc: AsyncNuaClient = kwargs["nc"]
return await nc.generate_retrieval(question, context, model)

@nua
async def remi(self, request: RemiRequest, **kwargs) -> RemiResponse:
"""
Perform a REMi evaluation over a RAG experience
:param request: RemiRequest
:return: RemiResponse
"""
nc: AsyncNuaClient = kwargs["nc"]
return await nc.remi(request)
47 changes: 47 additions & 0 deletions nuclia/tests/test_nua/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from nuclia.lib.nua_responses import ChatModel, TextGenerativeResponse, UserPrompt
from nuclia.sdk.predict import AsyncNucliaPredict, NucliaPredict
import pytest
from nuclia_models.predict.remi import RemiRequest


def test_predict(testing_config):
Expand Down Expand Up @@ -119,3 +121,48 @@ async def test_nua_parse(testing_config):
)
)
assert "SPORTS" in results.object["document_type"]


def test_nua_remi(testing_config):
np = NucliaPredict()
results = np.remi(
RemiRequest(
user_id="Nuclia PY CLI",
question="What is the capital of France?",
answer="Paris is the capital of france!",
contexts=[
"Paris is the capital of France.",
"Berlin is the capital of Germany.",
],
)
)
assert results.answer_relevance.score >= 4

assert results.context_relevance[0] >= 4
assert results.groundedness[0] >= 4

assert results.context_relevance[1] < 2
assert results.groundedness[1] < 2


@pytest.mark.asyncio
async def test_nua_async_remi(testing_config):
np = AsyncNucliaPredict()
results = await np.remi(
RemiRequest(
user_id="Nuclia PY CLI",
question="What is the capital of France?",
answer="Paris is the capital of france!",
contexts=[
"Paris is the capital of France.",
"Berlin is the capital of Germany.",
],
)
)
assert results.answer_relevance.score >= 4

assert results.context_relevance[0] >= 4
assert results.groundedness[0] >= 4

assert results.context_relevance[1] < 2
assert results.groundedness[1] < 2

0 comments on commit 4da07bd

Please sign in to comment.