Skip to content

Commit

Permalink
new API to the workflow run page (Skyvern-AI#1400)
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng authored Dec 18, 2024
1 parent b8e2527 commit 58413db
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 178 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add index to ObserverThoughtModel db table
Revision ID: cf45479f484c
Revises: 411dd89f3df9
Create Date: 2024-12-17 06:51:04.086890+00:00
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cf45479f484c"
down_revision: Union[str, None] = "411dd89f3df9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(
"observer_cruise_index", "observer_thoughts", ["organization_id", "observer_cruise_id"], unique=False
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("observer_cruise_index", table_name="observer_thoughts")
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@
DecisiveAction,
UserDefinedError,
WebAction,
parse_actions,
)
from skyvern.webeye.actions.caching import retrieve_action_plan
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.parse_actions import parse_actions
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
Expand Down
35 changes: 35 additions & 0 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,25 @@ async def get_task_actions(self, task_id: str, organization_id: str | None = Non
LOG.error("UnexpectedError", exc_info=True)
raise

async def get_tasks_actions(self, task_ids: list[str], organization_id: str | None = None) -> list[Action]:
try:
async with self.Session() as session:
query = (
select(ActionModel)
.filter(ActionModel.organization_id == organization_id)
.filter(ActionModel.task_id.in_(task_ids))
.order_by(ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]

except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise

async def get_first_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
try:
async with self.Session() as session:
Expand Down Expand Up @@ -1858,6 +1877,22 @@ async def get_observer_thought(
return ObserverThought.model_validate(observer_thought)
return None

async def get_observer_cruise_thoughts(
self,
observer_cruise_id: str,
organization_id: str | None = None,
) -> list[ObserverThought]:
async with self.Session() as session:
observer_thoughts = (
await session.scalars(
select(ObserverThoughtModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(organization_id=organization_id)
.order_by(ObserverThoughtModel.created_at)
)
).all()
return [ObserverThought.model_validate(thought) for thought in observer_thoughts]

async def create_observer_cruise(
self,
workflow_run_id: str | None = None,
Expand Down
1 change: 1 addition & 0 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ class ObserverCruiseModel(Base):

class ObserverThoughtModel(Base):
__tablename__ = "observer_thoughts"
__table_args__ = (Index("observer_cruise_index", "organization_id", "observer_cruise_id"),)

observer_thought_id = Column(String, primary_key=True, default=generate_observer_thought_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True)
Expand Down
89 changes: 88 additions & 1 deletion skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.organizations import (
Expand All @@ -52,12 +52,14 @@
TaskResponse,
TaskStatus,
)
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunEvent, WorkflowRunEventType
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow,
FailedToUpdateWorkflow,
WorkflowParameterMissingRequiredValue,
)
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.forge.sdk.workflow.models.workflow import (
RunWorkflowResponse,
Workflow,
Expand Down Expand Up @@ -720,6 +722,91 @@ async def get_workflow_run(
)


@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/events",
)
@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/events/",
)
async def get_workflow_run_events(
workflow_id: str,
workflow_run_id: str,
observer_cruise_id: str | None = None,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[WorkflowRunEvent]:
# get all the tasks for the workflow run
tasks = await app.DATABASE.get_tasks(
page,
page_size,
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
workflow_run_events: list[WorkflowRunEvent] = []
for task in tasks:
block_type = BlockType.TASK
if task.task_type == TaskType.general:
if not task.navigation_goal and task.data_extraction_goal:
block_type = BlockType.EXTRACTION
elif task.navigation_goal and not task.data_extraction_goal:
block_type = BlockType.NAVIGATION
elif task.task_type == TaskType.validation:
block_type = BlockType.VALIDATION
elif task.task_type == TaskType.action:
block_type = BlockType.ACTION
event = WorkflowRunEvent(
type=WorkflowRunEventType.block,
block=WorkflowRunBlock(
workflow_run_id=workflow_run_id,
block_type=block_type,
label=task.title,
title=task.title,
url=task.url,
status=task.status,
navigation_goal=task.navigation_goal,
data_extraction_goal=task.data_extraction_goal,
data_schema=task.extracted_information_schema,
terminate_criterion=task.terminate_criterion,
complete_criterion=task.complete_criterion,
created_at=task.created_at,
modified_at=task.modified_at,
),
created_at=task.created_at,
modified_at=task.modified_at,
)
workflow_run_events.append(event)
# get all the actions for all the tasks
actions = await app.DATABASE.get_tasks_actions(
[task.task_id for task in tasks], organization_id=current_org.organization_id
)
for action in actions:
workflow_run_events.append(
WorkflowRunEvent(
type=WorkflowRunEventType.action,
action=action,
created_at=action.created_at or datetime.datetime.utcnow(),
modified_at=action.modified_at or datetime.datetime.utcnow(),
)
)
# get all the thoughts for the cruise
if observer_cruise_id:
thoughts = await app.DATABASE.get_observer_cruise_thoughts(
observer_cruise_id, organization_id=current_org.organization_id
)
for thought in thoughts:
workflow_run_events.append(
WorkflowRunEvent(
type=WorkflowRunEventType.thought,
thought=thought,
created_at=thought.created_at,
modified_at=thought.modified_at,
)
)
workflow_run_events.sort(key=lambda x: x.created_at)
return workflow_run_events


@base_router.get(
"/workflows/runs/{workflow_run_id}",
response_model=WorkflowRunStatusResponse,
Expand Down
45 changes: 45 additions & 0 deletions skyvern/forge/sdk/schemas/workflow_runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from datetime import datetime
from enum import StrEnum
from typing import Any

from pydantic import BaseModel

from skyvern.forge.sdk.schemas.observers import ObserverThought
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.webeye.actions.actions import Action


class WorkflowRunBlock(BaseModel):
workflow_run_block_id: str = "placeholder"
workflow_run_id: str
parent_workflow_run_block_id: str | None = None
block_type: BlockType
label: str | None = None
title: str | None = None
status: str | None = None
output: dict | list | str | None = None
continue_on_failure: bool = False
task_id: str | None = None
url: str | None = None
navigation_goal: str | None = None
data_extraction_goal: str | None = None
data_schema: dict[str, Any] | list | str | None = None
terminate_criterion: str | None = None
complete_criterion: str | None = None
created_at: datetime
modified_at: datetime


class WorkflowRunEventType(StrEnum):
action = "action"
thought = "thought"
block = "block"


class WorkflowRunEvent(BaseModel):
type: WorkflowRunEventType
action: Action | None = None
thought: ObserverThought | None = None
block: WorkflowRunBlock | None = None
created_at: datetime
modified_at: datetime
Loading

0 comments on commit 58413db

Please sign in to comment.