Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change the validation logic to regression validation logic, #357

Merged
131 changes: 109 additions & 22 deletions src/sherpa_ai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sherpa_ai.policies.base import BasePolicy
from sherpa_ai.verbose_loggers.base import BaseVerboseLogger
from sherpa_ai.verbose_loggers.verbose_loggers import DummyVerboseLogger

from langchain.base_language import BaseLanguageModel

# Avoid circular import
if TYPE_CHECKING:
Expand All @@ -32,7 +32,10 @@ def __init__(
validation_steps: int = 1,
validations: List[BaseOutputProcessor] = [],
feedback_agent_name: str = "critic",
global_regen_max: int = 12,
llm: BaseLanguageModel = None,
):
self.llm = llm
self.name = name
self.description = description
self.shared_memory = shared_memory
Expand All @@ -45,6 +48,7 @@ def __init__(
self.verbose_logger = verbose_logger
self.actions = actions
self.validation_steps = validation_steps
self.global_regen_max = global_regen_max
self.validations = validations
self.feedback_agent_name = feedback_agent_name

Expand Down Expand Up @@ -91,13 +95,67 @@ def run(self):
EventType.action_output, self.name, action_output
)

result = self.validate_output()
result = (
self.validate_output()
if len(self.validations) > 0
else self.synthesize_output()
)

logger.debug(f"```🤖{self.name} wrote: {result}```")

self.shared_memory.add(EventType.result, self.name, result)
return result

# The validation_iterator function is responsible for iterating through each instantiated validation in the 'self.validations' list.
# It performs the necessary validation steps for each validation, updating the belief system and synthesizing output if needed.
# It keeps track of the global regeneration count, whether all validations have passed, and if any validation has been escaped.
# The function returns the updated global regeneration count, the status of all validations, whether any validation has been escaped, and the synthesized output.
def validation_iterator(
self,
validations,
global_regen_count,
all_pass,
validation_is_scaped,
result,
):

for i in range(len(validations)):
validation = validations[i]
logger.info(f"validation_running: {validation.__class__.__name__}")
logger.info(f"validation_count: {validation.count}")
# this checks if the validator has already exceeded the validation steps limit.
if validation.count < self.validation_steps:
self.belief.update_internal(EventType.result, self.name, result)
validation_result = validation.process_output(
text=result, belief=self.belief, llm=self.llm
)
logger.info(f"validation_result: {validation_result}")
if not validation_result.is_valid:
self.belief.update_internal(
EventType.feedback,
self.feedback_agent_name,
validation_result.feedback,
)
result = self.synthesize_output()
global_regen_count += 1
break

# if all validations passed then set all_pass to True
elif i == len(validations) - 1:
result = validation_result.result
all_pass = True
else:
result = validation_result.result
# if validation is the last one and surpassed the validation steps limit then finish the loop with all_pass and mention there is a scaped validation.
elif i == len(validations) - 1:
validation_is_scaped = True
all_pass = True

else:
# if the validation has already reached the validation steps limit then continue to the next validation.
validation_is_scaped = True
return global_regen_count, all_pass, validation_is_scaped, result

def validate_output(self):
"""
Validate the synthesized output through a series of validation steps.
Expand All @@ -112,35 +170,64 @@ def validate_output(self):
Returns:
str: The synthesized output after validation.
"""
failed_validation = []
result = ""
# create array of instance of validation so that we can keep track of how many times regeneration happened.
all_pass = False
validation_is_scaped = False
iteration_count = 0
result = self.synthesize_output()
global_regen_count = 0

# reset the state of all the validation before starting the validation process.
for validation in self.validations:
for count in range(self.validation_steps):
self.belief.update_internal(EventType.result, self.name, result)
validation.reset_state()

validations = self.validations

# this loop will run until max regeneration reached or all validations have failed
while self.global_regen_max > global_regen_count and not all_pass:
logger.info(f"validations_size: {len(validations)}")
iteration_count += 1
logger.info(f"main_iteration: {iteration_count}")
logger.info(f"regen_count: {global_regen_count}")

global_regen_count, all_pass, validation_is_scaped, result = (
self.validation_iterator(
all_pass=all_pass,
global_regen_count=global_regen_count,
validation_is_scaped=validation_is_scaped,
validations=validations,
result=result,
)
)
# if all didn't pass or validation reached max regeneration run the validation one more time but no regeneration.
if validation_is_scaped or self.global_regen_max >= global_regen_count:
failed_validations = []

for validation in validations:
validation_result = validation.process_output(
text=result, belief=self.belief, iteration_count=count
text=result, belief=self.belief, llm=self.llm
)

if validation_result.is_valid:
result = validation_result.result
break
if not validation_result.is_valid:
failed_validations.append(validation)
else:
self.belief.update_internal(
EventType.feedback,
self.feedback_agent_name,
validation_result.feedback,
)
result = self.synthesize_output()

if count >= self.validation_steps:
failed_validation.append(validation)
result = validation_result.result

if len(failed_validation) > 0:
# if the validation failed after all steps, append the error messages to the result
result += "\n".join(
failed_validation.get_failure_message()
for failed_validation in failed_validation
for failed_validation in failed_validations
)

else:

# check if validation is not passed after all the attempts if so return the error message.
result += "\n".join(
20001LastOrder marked this conversation as resolved.
Show resolved Hide resolved
(
inst_val.get_failure_message()
if inst_val.count == self.validation_steps
else ""
)
for inst_val in validations
)

self.belief.update_internal(EventType.result, self.name, result)
Expand Down
24 changes: 14 additions & 10 deletions src/sherpa_ai/agents/qa_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
actions: List[BaseAction] = [],
validation_steps: int = 1,
validations: List[BaseOutputProcessor] = [],
global_regen_max: int = 5,
):
"""
The QA agent handles a single question-answering task.
Expand All @@ -63,16 +64,18 @@ def __init__(
validations (List[BaseOutputProcessor], optional): The list of validations the agent will perform. Defaults to [].
"""
super().__init__(
name,
description + "\n\n" + f"Your name is {name}.",
shared_memory,
belief,
policy,
num_runs,
verbose_logger,
actions,
validation_steps,
validations,
llm=llm,
name=name,
description=description + "\n\n" + f"Your name is {name}.",
shared_memory=shared_memory,
belief=belief,
policy=policy,
num_runs=num_runs,
verbose_logger=verbose_logger,
actions=actions,
validation_steps=validation_steps,
validations=validations,
global_regen_max=global_regen_max,
)

if self.policy is None:
Expand All @@ -89,6 +92,7 @@ def __init__(
belief = Belief()
self.belief = belief
self.citation_enabled = False

for validation in self.validations:
if isinstance(validation, CitationValidation):
self.citation_enabled = True
Expand Down
7 changes: 6 additions & 1 deletion src/sherpa_ai/output_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class BaseOutputProcessor(ABC):
Defines the interface for processing output text.

Attributes:
None
count (int): Abstract global variable representing the count of failed validations.

Methods:
process_output(text: str) -> Tuple[bool, str]:
Expand All @@ -52,6 +52,11 @@ class BaseOutputProcessor(ABC):

"""

count: int = 0

    def reset_state(self):
        """Reset the failed-validation counter (``count``) back to zero.

        Called before a fresh validation run so that earlier failures do
        not count against this processor's validation-step budget.
        """
        self.count = 0

@abstractmethod
def process_output(self, text: str, **kwargs) -> ValidationResult:
"""
Expand Down
18 changes: 14 additions & 4 deletions src/sherpa_ai/output_parsers/entity_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sherpa_ai.memory import Belief
from sherpa_ai.output_parsers.base import BaseOutputProcessor
from sherpa_ai.output_parsers.validation_result import ValidationResult
from langchain.base_language import BaseLanguageModel
from sherpa_ai.utils import (
extract_entities,
text_similarity,
Expand Down Expand Up @@ -36,7 +37,7 @@ class EntityValidation(BaseOutputProcessor):
"""

def process_output(
self, text: str, belief: Belief, iteration_count: int = 1
self, text: str, belief: Belief, llm: BaseLanguageModel = None, **kwargs
) -> ValidationResult:
"""
Verifies that entities within `text` exist in the `belief` source text.
Expand All @@ -58,7 +59,7 @@ def process_output(
exclude_types=[EventType.feedback, EventType.result],
)
entity_exist_in_source, error_message = self.check_entities_match(
text, source, self.similarity_picker(iteration_count)
text, source, self.similarity_picker(self.count), llm
)
if entity_exist_in_source:
return ValidationResult(
Expand All @@ -67,6 +68,7 @@ def process_output(
feedback="",
)
else:
self.count += 1
return ValidationResult(
is_valid=False,
result=text,
Expand All @@ -93,7 +95,11 @@ def get_failure_message(self) -> str:
return "Some enitities from the source might not be mentioned."

def check_entities_match(
self, result: str, source: str, stage: TextSimilarityMethod
self,
result: str,
source: str,
stage: TextSimilarityMethod,
llm: BaseLanguageModel,
):
"""
Check if entities extracted from a question are present in an answer.
Expand All @@ -118,9 +124,13 @@ def check_entities_match(
return text_similarity_by_metrics(
check_entity=check_entity, source_entity=source_entity
)
else:
elif stage > 1 and llm is not None:
return text_similarity_by_llm(
llm=llm,
source_entity=source_entity,
result=result,
source=source,
)
return text_similarity_by_metrics(
check_entity=check_entity, source_entity=source_entity
)
1 change: 1 addition & 0 deletions src/sherpa_ai/output_parsers/number_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def process_output(self, text: str, belief: Belief, **kwargs) -> ValidationResul
feedback="",
)
else:
self.count += 1
return ValidationResult(
is_valid=False,
result=text,
Expand Down
9 changes: 2 additions & 7 deletions src/sherpa_ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nltk.metrics import edit_distance, jaccard_distance
from pypdf import PdfReader
from word2number import w2n
from langchain.base_language import BaseLanguageModel

import sherpa_ai.config as cfg
from sherpa_ai.database.user_usage_tracker import UserUsageTracker
Expand Down Expand Up @@ -513,6 +514,7 @@ def json_from_text(text: str):


def text_similarity_by_llm(
llm: BaseLanguageModel,
source_entity: List[str],
source,
result,
Expand All @@ -533,13 +535,6 @@ def text_similarity_by_llm(
dict: Result of the check containing 'entity_exist' and 'messages'.
"""

llm = SherpaOpenAI(
temperature=cfg.TEMPERATURE,
openai_api_key=cfg.OPENAI_API_KEY,
user_id=user_id,
team_id=team_id,
)

instruction = f"""
I have a question and an answer. I want you to confirm whether the entities from the question are all mentioned in some form within the answer.

Expand Down
Loading