feat: Add initial support for structured output (#25)
xmnlab authored Nov 21, 2024
1 parent b621e56 commit 64856e1
Showing 11 changed files with 455 additions and 463 deletions.
664 changes: 228 additions & 436 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -42,6 +42,8 @@ pypdf = ">=5"
 langchain = ">=0.3.7"
 langchain-community = ">=0.3.7"
 spacy = ">=3"
+instructor = ">=1"
+pydantic = ">=2"
 
 [tool.poetry.extras]
 cpu = ["torch", "torchvision"]
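The two new runtime dependencies are what the rest of this commit builds on: pydantic supplies the BaseModel classes used to describe structured outputs, and instructor patches the LLM clients so they can return validated instances of those models. A quick sanity check that both resolve inside the project's poetry environment (a sketch; nothing here is rago-specific):

from importlib.metadata import version

# Both constraints come from the pyproject.toml change above.
print(version('instructor'))  # expected to satisfy >=1
print(version('pydantic'))    # expected to satisfy >=2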
5 changes: 3 additions & 2 deletions src/rago/core.py
@@ -4,6 +4,7 @@
 
 from typing import Any
 
+from pydantic import BaseModel
 from typeguard import typechecked
 
 from rago.augmented.base import AugmentedBase
@@ -50,7 +51,7 @@ def __init__(
             'generation': generation.logs,
         }
 
-    def prompt(self, query: str, device: str = 'auto') -> str:
+    def prompt(self, query: str, device: str = 'auto') -> str | BaseModel:
         """Run the pipeline for a specific prompt.
 
         Parameters
@@ -72,7 +73,7 @@ def prompt(self, query: str, device: str = 'auto') -> str:
         aug_data = self.augmented.search(query, ret_data)
         self.logs['augmented']['result'] = aug_data
 
-        gen_data: str = self.generation.generate(query, context=aug_data)
+        gen_data = self.generation.generate(query, context=aug_data)
         self.logs['generation']['result'] = gen_data
 
         return gen_data
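Because Rago.prompt is now typed str | BaseModel, callers that enable structured output need to branch on the result type. A minimal hedged sketch (the pipeline object is assumed to be constructed as in the tests further down):

from pydantic import BaseModel

from rago import Rago

def show_answer(rag: Rago, query: str) -> None:
    """Handle both shapes that Rago.prompt can now return (sketch)."""
    result = rag.prompt(query)
    if isinstance(result, BaseModel):
        print(result.model_dump())  # structured path: a validated model
    else:
        print(result)  # default path: plain generated text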
31 changes: 19 additions & 12 deletions src/rago/generation/base.py
@@ -3,10 +3,11 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from typing import Any
+from typing import Any, Optional, Type
 
 import torch
 
+from pydantic import BaseModel
 from typeguard import typechecked
@@ -26,6 +27,7 @@ class GenerationBase:
     prompt_template: str = (
         'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
     )
+    structured_output: Optional[Type[BaseModel]] = None
 
     # default parameters that can be overwritten by the derived class
     default_device_name: str = 'cpu'
@@ -44,6 +46,7 @@ def __init__(
         prompt_template: str = '',
         output_max_length: int = 500,
         device: str = 'auto',
+        structured_output: Optional[Type[BaseModel]] = None,
         logs: dict[str, Any] = {},
     ) -> None:
         """Initialize Generation class.
@@ -58,29 +61,33 @@ def __init__(
         output_max_length : int
             Maximum length of the generated output.
         device: str (default=auto)
+        structured_output: Optional[Type[BaseModel]] = None
         logs: dict[str, Any] = {}
         """
-        self.api_key = api_key
-        self.model_name = model_name or self.default_model_name
-        self.output_max_length = (
+        self.api_key: str = api_key
+        self.model_name: str = model_name or self.default_model_name
+        self.output_max_length: int = (
             output_max_length or self.default_output_max_length
         )
-        self.temperature = temperature or self.default_temperature
+        self.temperature: float = temperature or self.default_temperature
 
-        self.prompt_template = prompt_template or self.default_prompt_template
+        self.prompt_template: str = (
+            prompt_template or self.default_prompt_template
+        )
+        self.structured_output: Optional[Type[BaseModel]] = structured_output
 
-        if self.device_name not in ['cpu', 'cuda', 'auto']:
+        if device not in ['cpu', 'cuda', 'auto']:
             raise Exception(
-                f'Device {self.device_name} not supported. '
-                'Options: cpu, cuda, auto.'
+                f'Device {device} not supported. ' 'Options: cpu, cuda, auto.'
             )
 
         cuda_available = torch.cuda.is_available()
-        self.device_name = (
+        self.device_name: str = (
             'cpu' if device == 'cpu' or not cuda_available else 'cuda'
        )
         self.device = torch.device(self.device_name)
 
-        self.logs = logs
+        self.logs: dict[str, Any] = logs
 
         self._validate()
         self._setup()
@@ -98,7 +105,7 @@ def generate(
         self,
         query: str,
         context: list[str],
-    ) -> str:
+    ) -> str | BaseModel:
         """Generate text with optional language parameter.
 
         Parameters
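To make the widened contract concrete, here is a hypothetical minimal subclass; it is not part of this commit, just a sketch of what GenerationBase now expects from a backend: _validate and _setup hooks invoked by __init__, and a generate() that may return either text or an instance of self.structured_output.

from __future__ import annotations

from pydantic import BaseModel

from rago.generation.base import GenerationBase

class EchoGen(GenerationBase):
    """Toy backend for illustration only (hypothetical)."""

    default_model_name = 'echo'

    def _validate(self) -> None:
        pass  # nothing to check for this toy backend

    def _setup(self) -> None:
        self.model = None  # no real client to build

    def generate(self, query: str, context: list[str]) -> str | BaseModel:
        if self.structured_output:
            # assumes the target model can be built from a 'name' field
            return self.structured_output(name=context[0])
        return query + ': ' + ' '.join(context)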
34 changes: 29 additions & 5 deletions src/rago/generation/gemini.py
@@ -5,7 +5,9 @@
 from typing import cast
 
 import google.generativeai as genai
+import instructor
 
+from pydantic import BaseModel
 from typeguard import typechecked
 
 from rago.generation.base import GenerationBase
@@ -20,19 +22,41 @@ class GeminiGen(GenerationBase):
     def _setup(self) -> None:
         """Set up the object with the initial parameters."""
         genai.configure(api_key=self.api_key)
-        self.model = genai.GenerativeModel(self.model_name)
+        model = genai.GenerativeModel(self.model_name)
+
+        self.model = (
+            instructor.from_gemini(
+                client=model,
+                mode=instructor.Mode.GEMINI_JSON,
+            )
+            if self.structured_output
+            else model
+        )
 
-    def generate(self, query: str, context: list[str]) -> str:
+    def generate(self, query: str, context: list[str]) -> str | BaseModel:
         """Generate text using Gemini model support."""
         input_text = self.prompt_template.format(
             query=query, context=' '.join(context)
         )
 
+        if not self.structured_output:
+            models_params_gen = {'contents': input_text}
+            response = self.model.generate_content(**models_params_gen)
+            self.logs['model_params'] = models_params_gen
+            return cast(str, response.text.strip())
+
+        messages = [
+            {'role': 'user', 'content': input_text},
+        ]
         model_params = {
-            'contents': input_text,
+            'messages': messages,
+            'response_model': self.structured_output,
         }
 
-        response = self.model.generate_content(**model_params)
+        response = self.model.create(
+            **model_params,
+        )
 
         self.logs['model_params'] = model_params
-        return cast(str, response.text.strip())
+
+        return cast(BaseModel, response)
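The branch above keeps the raw genai client for plain-text generation and only routes through instructor when a response model is requested: instructor.from_gemini wraps the GenerativeModel so that create(messages=..., response_model=...) returns a validated pydantic instance. A standalone sketch of that same pattern outside rago (API key and model name are illustrative):

import google.generativeai as genai
import instructor
from pydantic import BaseModel

class Answer(BaseModel):
    name: str

genai.configure(api_key='YOUR-API-KEY')
client = instructor.from_gemini(
    client=genai.GenerativeModel('gemini-1.5-flash'),
    mode=instructor.Mode.GEMINI_JSON,
)
answer = client.create(
    messages=[{'role': 'user', 'content': 'Which animal is the fastest?'}],
    response_model=Answer,  # instructor validates the reply into Answer
)
print(answer.name)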
14 changes: 10 additions & 4 deletions src/rago/generation/hugging_face.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import warnings
+
 import torch
 
 from transformers import T5ForConditionalGeneration, T5Tokenizer
@@ -22,13 +24,17 @@ def _validate(self) -> None:
                 f'The given model {self.model_name} is not supported.'
             )
 
+        if self.structured_output:
+            warnings.warn(
+                'Structured output is not supported yet in '
+                f'{self.__class__.__name__}.'
+            )
+
     def _setup(self) -> None:
         """Set models to t5-small models."""
         self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
-        self.model = T5ForConditionalGeneration.from_pretrained(
-            self.model_name
-        )
-        self.model = self.model.to(self.device)
+        model = T5ForConditionalGeneration.from_pretrained(self.model_name)
+        self.model = model.to(self.device)
 
     def generate(self, query: str, context: list[str]) -> str:
         """Generate the text from the query and augmented context."""
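Since the T5 backend cannot emit structured output yet, the new _validate branch degrades gracefully to a warning instead of raising. A hedged sketch of how that surfaces to callers (the class name HuggingFaceGen and the import paths are assumed, not shown in this diff):

import warnings

from rago.generation import HuggingFaceGen  # name assumed
from tests.models import AnimalModel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    HuggingFaceGen(structured_output=AnimalModel)

assert any('not supported yet' in str(w.message) for w in caught)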
8 changes: 8 additions & 0 deletions src/rago/generation/llama.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import warnings
+
 import torch
 
 from langdetect import detect
@@ -27,6 +29,12 @@ def _validate(self) -> None:
             'by meta.'
         )
 
+        if self.structured_output:
+            warnings.warn(
+                'Structured output is not supported yet in '
+                f'{self.__class__.__name__}.'
+            )
+
     def _setup(self) -> None:
         """Set up the object with the initial parameters."""
         self.tokenizer = AutoTokenizer.from_pretrained(
19 changes: 16 additions & 3 deletions src/rago/generation/openai.py
@@ -4,8 +4,10 @@
 
 from typing import cast
 
+import instructor
 import openai
 
+from pydantic import BaseModel
 from typeguard import typechecked
 
 from rago.generation.base import GenerationBase
@@ -19,13 +21,17 @@ class OpenAIGen(GenerationBase):
 
     def _setup(self) -> None:
         """Set up the object with the initial parameters."""
-        self.model = openai.OpenAI(api_key=self.api_key)
+        model = openai.OpenAI(api_key=self.api_key)
+
+        self.model = (
+            instructor.from_openai(model) if self.structured_output else model
+        )
 
     def generate(
         self,
         query: str,
         context: list[str],
-    ) -> str:
+    ) -> str | BaseModel:
         """Generate text using OpenAI's API with dynamic model support."""
         input_text = self.prompt_template.format(
             query=query, context=' '.join(context)
@@ -44,8 +50,15 @@ def generate(
             presence_penalty=0.3,
         )
 
+        if self.structured_output:
+            model_params['response_model'] = self.structured_output
+
         response = self.model.chat.completions.create(**model_params)
 
         self.logs['model_params'] = model_params
 
-        return cast(str, response.choices[0].message.content.strip())
+        has_choices = hasattr(response, 'choices')
+
+        if has_choices and isinstance(response.choices, list):
+            return cast(str, response.choices[0].message.content.strip())
+        return cast(BaseModel, response)
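The OpenAI path mirrors the Gemini one: instructor.from_openai patches the client so chat.completions.create accepts response_model and returns the pydantic instance directly, which is why the hasattr check on choices is needed to tell the two response shapes apart. The same pattern standalone (key and model name are illustrative):

import instructor
import openai
from pydantic import BaseModel

class Answer(BaseModel):
    name: str

client = instructor.from_openai(openai.OpenAI(api_key='YOUR-API-KEY'))
answer = client.chat.completions.create(
    model='gpt-4o-mini',
    messages=[{'role': 'user', 'content': 'Which animal is the largest?'}],
    response_model=Answer,  # parsed and validated by instructor
)
print(answer.name)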
28 changes: 28 additions & 0 deletions tests/models.py
@@ -0,0 +1,28 @@
+"""Models used for the unit tests."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class AnimalModel(BaseModel):
+    """Model for animals."""
+
+    name: Literal[
+        'Blue Whale',
+        'Peregrine Falcon',
+        'Giant Panda',
+        'Cheetah',
+        'Komodo Dragon',
+        'Arctic Fox',
+        'Monarch Butterfly',
+        'Great White Shark',
+        'Honey Bee',
+        'Emperor Penguin',
+        'Unknown',
+    ] = Field(
+        ...,
+        description='The predicted class label.',
+    )
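The Literal annotation turns classification into validation: pydantic accepts only the listed labels, so a hallucinated animal fails loudly instead of slipping through. For instance, with AnimalModel as defined above:

from pydantic import ValidationError

AnimalModel(name='Cheetah')  # accepted: value is one of the literals

try:
    AnimalModel(name='Tyrannosaurus')  # rejected: outside the Literal set
except ValidationError as err:
    print(err)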
54 changes: 54 additions & 0 deletions tests/test_gemini.py
@@ -2,13 +2,17 @@
 
 import os
 
+from typing import cast
+
 import pytest
 
 from rago import Rago
 from rago.augmented import SentenceTransformerAug
 from rago.generation import GeminiGen
 from rago.retrieval import StringRet
 
+from .models import AnimalModel
+
 
 @pytest.fixture
 def api_key(env) -> str:
@@ -52,3 +56,53 @@ def test_gemini_generation(animals_data: list[str], api_key: str) -> None:
     assert logs['retrieval']
     assert logs['augmented']
     assert logs['generation']
+
+
+@pytest.mark.skip_on_ci
+@pytest.mark.parametrize(
+    'question,expected_answer',
+    [
+        ('What animal is larger than a dinosaur?', 'Blue Whale'),
+        (
+            'What animal is renowned as the fastest animal on the planet?',
+            'Peregrine Falcon',
+        ),
+    ],
+)
+def test_rag_gemini_structured_output(
+    api_key: str,
+    animals_data: list[str],
+    question: str,
+    expected_answer: str,
+) -> None:
+    """Test RAG pipeline with Gemini."""
+    logs = {
+        'retrieval': {},
+        'augmented': {},
+        'generation': {},
+    }
+
+    rag = Rago(
+        retrieval=StringRet(animals_data, logs=logs['retrieval']),
+        augmented=SentenceTransformerAug(top_k=3, logs=logs['augmented']),
+        generation=GeminiGen(
+            api_key=api_key,
+            model_name='gemini-1.5-flash',
+            logs=logs['generation'],
+            structured_output=AnimalModel,
+        ),
+    )
+
+    result = cast(AnimalModel, rag.prompt(question))
+
+    error_message = (
+        f'Expected response to mention `{expected_answer}`. '
+        f'Result: `{result.name}`.'
+    )
+
+    assert expected_answer == result.name, error_message
+
+    # check if logs have been used
+    assert logs['retrieval']
+    assert logs['augmented']
+    assert logs['generation']