From 44f93d1fd0655d353f437befc1d0d8a0623cd362 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 23 Dec 2024 07:00:42 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=84=20synced=20local=20'skyvern/'=20wi?= =?UTF-8?q?th=20remote=20'skyvern/'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit > [!IMPORTANT] > Enhance observer thought handling by adding `OBSERVER_THOUGHT` log entity type and updating observer thought processes in `client.py` and `observer_service.py`. > > - **Behavior**: > - Add `OBSERVER_THOUGHT` to `LogEntityType` in `models.py`. > - Implement `update_observer_thought()` in `client.py` to update observer thoughts with fields like `workflow_run_block_id`, `observation`, `thought`, and `answer`. > - Modify `run_observer_cruise()` in `observer_service.py` to create and update observer thoughts during the observer cruise process. > - **Database**: > - Add `update_observer_thought()` method in `client.py` to handle updates to `ObserverThoughtModel`. > - **Service**: > - Update `run_observer_cruise()` in `observer_service.py` to use `create_observer_thought()` and `update_observer_thought()` for handling observer thoughts. > - **Misc**: > - Add `OBSERVER_THOUGHT` to `EntityType` in `agent_protocol.py`. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=Skyvern-AI%2Fskyvern-cloud&utm_source=github&utm_medium=referral) for fd92b851c149a2d27cec0fc4121e6e020c950c99. It will automatically update as commits are pushed. --- skyvern/forge/sdk/artifact/models.py | 1 + skyvern/forge/sdk/db/client.py | 31 +++++++++++++++++++ skyvern/forge/sdk/routes/agent_protocol.py | 2 ++ .../forge/sdk/services/observer_service.py | 18 +++++++---- 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/skyvern/forge/sdk/artifact/models.py b/skyvern/forge/sdk/artifact/models.py index 307c902bf4..c91a37a400 100644 --- a/skyvern/forge/sdk/artifact/models.py +++ b/skyvern/forge/sdk/artifact/models.py @@ -84,3 +84,4 @@ class LogEntityType(StrEnum): TASK = "task" WORKFLOW_RUN = "workflow_run" WORKFLOW_RUN_BLOCK = "workflow_run_block" + OBSERVER = "observer" diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 0460f08ba5..74f1f38d2a 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1975,6 +1975,37 @@ async def create_observer_thought( await session.refresh(new_observer_thought) return ObserverThought.model_validate(new_observer_thought) + async def update_observer_thought( + self, + observer_thought_id: str, + workflow_run_block_id: str | None = None, + observation: str | None = None, + thought: str | None = None, + answer: str | None = None, + organization_id: str | None = None, + ) -> ObserverThought: + async with self.Session() as session: + observer_thought = ( + await session.scalars( + select(ObserverThoughtModel) + .filter_by(observer_thought_id=observer_thought_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if observer_thought: + if workflow_run_block_id: + observer_thought.workflow_run_block_id = workflow_run_block_id + if observation: + observer_thought.observation = observation + if thought: + observer_thought.thought = thought + if answer: + observer_thought.answer = answer + await session.commit() + await session.refresh(observer_thought) + return ObserverThought.model_validate(observer_thought) + raise NotFoundError(f"ObserverThought {observer_thought_id}") + async def update_observer_cruise( self, observer_cruise_id: str, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 020d971271..82e52ff409 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -478,6 +478,7 @@ class EntityType(str, Enum): TASK = "task" WORKFLOW_RUN = "workflow_run" WORKFLOW_RUN_BLOCK = "workflow_run_block" + OBSERVER_THOUGHT = "observer_thought" entity_type_to_param = { @@ -485,6 +486,7 @@ class EntityType(str, Enum): EntityType.TASK: "task_id", EntityType.WORKFLOW_RUN: "workflow_run_id", EntityType.WORKFLOW_RUN_BLOCK: "workflow_run_block_id", + EntityType.OBSERVER_THOUGHT: "observer_thought_id", } diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index fdbc5473bc..92de49b0e6 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -231,8 +231,17 @@ async def run_observer_cruise( task_history=task_history, local_datetime=datetime.now(context.tz_info).isoformat(), ) + observer_thought = await app.DATABASE.create_observer_thought( + observer_cruise_id=observer_cruise_id, + organization_id=organization_id, + workflow_run_id=workflow_run.workflow_run_id, + workflow_id=workflow.workflow_id, + workflow_permanent_id=workflow.workflow_permanent_id, + ) observer_response = await app.LLM_API_HANDLER( - prompt=observer_prompt, screenshots=scraped_page.screenshots, observer_cruise=observer_cruise + prompt=observer_prompt, + screenshots=scraped_page.screenshots, + observer_thought=observer_thought, ) LOG.info( "Observer response", @@ -247,12 +256,9 @@ async def run_observer_cruise( thoughts: str = observer_response.get("thoughts", "") plan: str = observer_response.get("plan", "") # Create and save observer thought - await app.DATABASE.create_observer_thought( - observer_cruise_id=observer_cruise_id, + await app.DATABASE.update_observer_thought( + observer_thought_id=observer_thought.observer_thought_id, organization_id=organization_id, - workflow_run_id=workflow_run.workflow_run_id, - workflow_id=workflow.workflow_id, - workflow_permanent_id=workflow.workflow_permanent_id, thought=thoughts, observation=observation, answer=plan,