Skip to content

Commit

Permalink
introduce complete verification
Browse files Browse the repository at this point in the history
  • Loading branch information
LawyZheng committed Nov 15, 2024
1 parent 54f793c commit 34c74f8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 36 deletions.
69 changes: 36 additions & 33 deletions skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
Action,
ActionType,
CompleteAction,
CompleteVerifyResult,
DecisiveAction,
UserDefinedError,
WebAction,
Expand Down Expand Up @@ -923,57 +924,59 @@ async def agent_step(
)
return failed_step, detailed_agent_step_output.get_clean_detailed_output()

@staticmethod
async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult:
    """Ask the LLM whether the task's navigation goal has been achieved.

    The scraped page is refreshed first, then the "check-user-goal" prompt is
    rendered from the task's navigation goal, its navigation payload and the
    refreshed element tree (HTML format). The prompt is sent, together with the
    refreshed screenshots, through ``app.LLM_API_HANDLER`` and the reply is
    validated into a ``CompleteVerifyResult``.

    ``page`` is accepted but not used in this method. No exception is caught
    here, so a failure in the refresh, the LLM call or the validation
    propagates to the caller.
    """
    log_context = {
        "task_id": task.task_id,
        "step_id": step.step_id,
        "workflow_run_id": task.workflow_run_id,
    }
    LOG.info("Checking if user goal is achieved after re-scraping the page", **log_context)

    refreshed_page = await scraped_page.refresh()

    # TODO: for now the generic check-user-goal prompt doubles as the complete
    # verification; a purpose-designed completion criterion may be needed later.
    goal = task.navigation_goal
    payload = task.navigation_payload
    element_tree = refreshed_page.build_element_tree(ElementTreeFormat.HTML)
    prompt = prompt_engine.load_prompt(
        "check-user-goal",
        navigation_goal=goal,
        navigation_payload=payload,
        elements=element_tree,
    )

    # This prompt is critical to the agent, so it goes through the primary LLM API handler.
    llm_reply = await app.LLM_API_HANDLER(
        prompt=prompt,
        step=step,
        screenshots=refreshed_page.screenshots,
    )
    return CompleteVerifyResult.model_validate(llm_reply)

@staticmethod
async def check_user_goal_complete(
page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> CompleteAction | None:
try:
LOG.info(
"Checking if user goal is achieved after re-scraping the page without screenshots",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
scraped_page_refreshed = await scraped_page.refresh()

verification_prompt = prompt_engine.load_prompt(
"check-user-goal",
navigation_goal=task.navigation_goal,
navigation_payload=task.navigation_payload,
elements=scraped_page_refreshed.build_element_tree(ElementTreeFormat.HTML),
verification_result = await app.agent.complete_verify(
page=page,
scraped_page=scraped_page,
task=task,
step=step,
)

# this prompt is critical to our agent so let's use the primary LLM API handler
verification_response = await app.LLM_API_HANDLER(
prompt=verification_prompt, step=step, screenshots=scraped_page_refreshed.screenshots
)
if "user_goal_achieved" not in verification_response or "thoughts" not in verification_response:
LOG.error(
"Invalid LLM response for user goal success verification, skipping verification",
verification_response=verification_response,
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
return None

user_goal_achieved: bool = verification_response["user_goal_achieved"]
# We don't want to return a complete action if the user goal is not achieved since we're checking at every step
if not user_goal_achieved:
if not verification_result.user_goal_achieved:
return None

return CompleteAction(
reasoning=verification_response["thoughts"],
reasoning=verification_result.thoughts,
data_extraction_goal=task.data_extraction_goal,
verified=True,
)

except Exception:
LOG.error(
"LLM verification failed for complete action, skipping LLM verification",
LOG.exception(
"Failed to check user goal complete, skipping",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
exc_info=True,
)
return None

Expand Down
7 changes: 7 additions & 0 deletions skyvern/forge/prompts/skyvern/check-user-goal.j2
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@ Make sure to ONLY return the JSON object in this format with no additional text
"thoughts": str, // Think step by step. What information makes you believe whether user goal has completed or not. Use information you see on the site to explain.
"user_goal_achieved": bool // True if the user goal has been completed, false otherwise.
}
```

Elements on the page:
```
{{ elements }}
```

User Goal:
```
{{ navigation_goal }}
```

User Details:
```
{{ navigation_payload }}
```
10 changes: 10 additions & 0 deletions skyvern/webeye/actions/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def __repr__(self) -> str:
return f"SelectOption(label={self.label}, value={self.value}, index={self.index})"


class CompleteVerifyResult(BaseModel):
    # Parsed verdict of the "check-user-goal" LLM verification.

    # True when the verification judged the user goal to be completed.
    user_goal_achieved: bool
    # The model's explanation for the verdict.
    thoughts: str
    # Optional extra field, None by default.
    # NOTE(review): nothing visible here populates it -- confirm against the prompt.
    page_info: str | None = None

    def __repr__(self) -> str:
        # The label must match the class name; it previously read
        # "CompleteVerifyResponse", which does not exist and misleads log readers.
        return f"CompleteVerifyResult(thoughts={self.thoughts}, user_goal_achieved={self.user_goal_achieved}, page_info={self.page_info})"


class InputOrSelectContext(BaseModel):
field: str | None = None
is_required: bool | None = None
Expand Down Expand Up @@ -226,6 +235,7 @@ class TerminateAction(DecisiveAction):

class CompleteAction(DecisiveAction):
    # Decisive action that declares the task complete.
    action_type: ActionType = ActionType.COMPLETE
    # False until the user goal has been confirmed. The complete-action handler
    # runs a verification pass while this is still False and sets it to True
    # once that verification succeeds.
    verified: bool = False
    # Optional data-extraction goal carried along with the completion.
    # NOTE(review): the handler checks this before extracting data, but the
    # extraction code itself is not visible here -- confirm.
    data_extraction_goal: str | None = None


Expand Down
41 changes: 38 additions & 3 deletions skyvern/webeye/actions/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,9 +1032,14 @@ async def handle_complete_action(
) -> list[ActionResult]:
# If this action has a source_action_id, then we need to make sure if the goal is actually completed.
if action.source_action_id:
LOG.info("CompleteAction has source_action_id, checking if goal is completed")
complete_action = await app.agent.check_user_goal_complete(page, scraped_page, task, step)
if complete_action is None:
LOG.info(
"CompleteAction has source_action_id, checking if goal is completed",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
verified_complete_action = await app.agent.check_user_goal_complete(page, scraped_page, task, step)
if verified_complete_action is None:
return [
ActionFailure(
exception=IllegitComplete(
Expand All @@ -1044,6 +1049,36 @@ async def handle_complete_action(
)
)
]
action.verified = True

if not action.verified:
LOG.info(
"CompleteAction hasn't been verified, going to verify the user goal",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
try:
verification_result = await app.agent.complete_verify(page, scraped_page, task, step)
except Exception as e:
LOG.exception(
"Failed to verify the complete action",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
return [ActionFailure(exception=e)]

if not verification_result.user_goal_achieved:
return [ActionFailure(exception=IllegitComplete(data={"error": verification_result.thoughts}))]

LOG.info(
"CompleteAction has been verified successfully",
task_id=task.task_id,
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
action.verified = True

extracted_data = None
if action.data_extraction_goal:
Expand Down

0 comments on commit 34c74f8

Please sign in to comment.