Skip to content

Commit

Permalink
Refactor validation logic and add global regeneration max parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Eyobyb committed May 29, 2024
1 parent 9e1cde9 commit 8c17cf4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
39 changes: 19 additions & 20 deletions src/sherpa_ai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,21 @@ def run(self):
self.shared_memory.add(EventType.result, self.name, result)
return result


# The validation_iterator function is responsible for iterating through each instantiated validation in the 'instantiated_validations' list.
# The validation_iterator function is responsible for iterating through each instantiated validation in the 'self.validations' list.
# It performs the necessary validation steps for each validation, updating the belief system and synthesizing output if needed.
# It keeps track of the global regeneration count, whether all validations have passed, and if any validation has been escaped.
# The function returns the updated global regeneration count, the status of all validations, whether any validation has been escaped, and the synthesized output.
def validation_iterator(
self,
instantiated_validations,
validations,
global_regen_count,
all_pass,
validation_is_scaped,
result,
):
for i in range(len(instantiated_validations)):
validation = instantiated_validations[i]

for i in range(len(validations)):
validation = validations[i]
logger.info(f"validation_running: {validation.__class__.__name__}")
logger.info(f"validation_count: {validation.count}")
# this checks if the validator has already exceeded the validation steps limit.
Expand All @@ -142,16 +141,16 @@ def validation_iterator(
break

# if all validations passed then set all_pass to True
elif i == len(instantiated_validations) - 1:
elif i == len(validations) - 1:
result = validation_result.result
all_pass = True
else:
result = validation_result.result
# if the validation is the last one and has already exceeded the validation steps limit, finish the loop with all_pass set and flag that a validation was escaped.
elif i == len(instantiated_validations) - 1:
elif i == len(validations) - 1:
validation_is_scaped = True
all_pass = True

else:
# if the validation has already reached the validation steps limit then continue to the next validation.
validation_is_scaped = True
Expand All @@ -173,21 +172,21 @@ def validate_output(self):
"""
result = ""
# create array of instance of validation so that we can keep track of how many times regeneration happened.
instantiated_validations = []
for validation in self.validations:
try:
instantiated_validations.append(validation())
except Exception as e:
instantiated_validations.append(validation)

all_pass = False
validation_is_scaped = False
iteration_count = 0
result = self.synthesize_output()
global_regen_count = 0

# reset the state of all the validation before starting the validation process.
for validation in self.validations:
validation.reset_state()

validations = self.validations

# this loop will run until the max regeneration count is reached or all validations have passed
while self.global_regen_max > global_regen_count and not all_pass:
logger.info(f"validations_size: {len(instantiated_validations)}")
logger.info(f"validations_size: {len(validations)}")
iteration_count += 1
logger.info(f"main_iteration: {iteration_count}")
logger.info(f"regen_count: {global_regen_count}")
Expand All @@ -197,15 +196,15 @@ def validate_output(self):
all_pass=all_pass,
global_regen_count=global_regen_count,
validation_is_scaped=validation_is_scaped,
instantiated_validations=instantiated_validations,
validations=validations,
result=result,
)
)
# if not all validations passed, or a validation reached its max regeneration, run the validations one more time without regeneration.
if validation_is_scaped or self.global_regen_max >= global_regen_count:
failed_validations = []

for validation in instantiated_validations:
for validation in validations:
validation_result = validation.process_output(
text=result, belief=self.belief, llm=self.llm
)
Expand All @@ -228,7 +227,7 @@ def validate_output(self):
if inst_val.count == self.validation_steps
else ""
)
for inst_val in instantiated_validations
for inst_val in validations
)

self.belief.update_internal(EventType.result, self.name, result)
Expand Down
3 changes: 3 additions & 0 deletions src/sherpa_ai/output_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class BaseOutputProcessor(ABC):

count: int = 0

def reset_state(self):
    """Reset this processor's ``count`` attribute to 0."""
    self.count = 0

@abstractmethod
def process_output(self, text: str, **kwargs) -> ValidationResult:
"""
Expand Down
12 changes: 6 additions & 6 deletions src/tests/integration_tests/test_regression_validator_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def test_regression_validator_flow(
agent_pool=None,
)

entity_validation = EntityValidation
number_validation = NumberValidation
number_validation_two = NumberValidation
citation_validation = CitationValidation
entity_validation = EntityValidation()
number_validation = NumberValidation()
number_validation_two = NumberValidation()
citation_validation = CitationValidation()
with patch.object(SearchTool, "_run", return_value=data):
task_agent = QAAgent(
llm=llm,
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_regression_validator_flow(
logger.info(results[0].content)
final_result = results[0].content

if not entity_validation().get_failure_message() in final_result:
if not entity_validation.get_failure_message() in final_result:
result_entities = [s.lower() for s in extract_entities(results[0].content)]
expected_entities = [s.lower() for s in expected_entities]
for entity in expected_entities:
Expand All @@ -96,7 +96,7 @@ def test_regression_validator_flow(

if not match_found:
assert False, entity + " was not found in resource"
if not number_validation().get_failure_message() in final_result:
if not number_validation.get_failure_message() in final_result:
for number in expected_number:
if not number in combined_number_extractor(results[0].content):
assert False, number + " was not found in resource"
Expand Down

0 comments on commit 8c17cf4

Please sign in to comment.