Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new API to the workflow run page #1400

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add index to ObserverThoughtModel db table

Revision ID: cf45479f484c
Revises: 411dd89f3df9
Create Date: 2024-12-17 06:51:04.086890+00:00

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cf45479f484c"
down_revision: Union[str, None] = "411dd89f3df9"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(
"observer_cruise_index", "observer_thoughts", ["organization_id", "observer_cruise_id"], unique=False
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("observer_cruise_index", table_name="observer_thoughts")
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@
DecisiveAction,
UserDefinedError,
WebAction,
parse_actions,
)
from skyvern.webeye.actions.caching import retrieve_action_plan
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.parse_actions import parse_actions
from skyvern.webeye.actions.responses import ActionResult
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
Expand Down
35 changes: 35 additions & 0 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,25 @@ async def get_task_actions(self, task_id: str, organization_id: str | None = Non
LOG.error("UnexpectedError", exc_info=True)
raise

async def get_tasks_actions(self, task_ids: list[str], organization_id: str | None = None) -> list[Action]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle the case where task_ids is an empty list to avoid unnecessary database queries. This is applicable to similar methods as well.

try:
async with self.Session() as session:
query = (
select(ActionModel)
.filter(ActionModel.organization_id == organization_id)
.filter(ActionModel.task_id.in_(task_ids))
.order_by(ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]

except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise

async def get_first_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
try:
async with self.Session() as session:
Expand Down Expand Up @@ -1789,6 +1808,22 @@ async def get_observer_thought(
return ObserverThought.model_validate(observer_thought)
return None

async def get_observer_cruise_thoughts(
self,
observer_cruise_id: str,
organization_id: str | None = None,
) -> list[ObserverThought]:
async with self.Session() as session:
observer_thoughts = (
await session.scalars(
select(ObserverThoughtModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(organization_id=organization_id)
.order_by(ObserverThoughtModel.created_at)
)
).all()
return [ObserverThought.model_validate(thought) for thought in observer_thoughts]

async def create_observer_cruise(
self,
workflow_run_id: str | None = None,
Expand Down
1 change: 1 addition & 0 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ class ObserverCruiseModel(Base):

class ObserverThoughtModel(Base):
__tablename__ = "observer_thoughts"
__table_args__ = (Index("observer_cruise_index", "organization_id", "observer_cruise_id"),)

observer_thought_id = Column(String, primary_key=True, default=generate_observer_thought_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=True)
Expand Down
89 changes: 88 additions & 1 deletion skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.organizations import (
Expand All @@ -51,12 +51,14 @@
TaskResponse,
TaskStatus,
)
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunEvent, WorkflowRunEventType
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow,
FailedToUpdateWorkflow,
WorkflowParameterMissingRequiredValue,
)
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.forge.sdk.workflow.models.workflow import (
RunWorkflowResponse,
Workflow,
Expand Down Expand Up @@ -643,6 +645,91 @@ async def get_workflow_run(
)


@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/events",
)
@base_router.get(
"/workflows/{workflow_id}/runs/{workflow_run_id}/events/",
)
async def get_workflow_run_events(
workflow_id: str,
workflow_run_id: str,
observer_cruise_id: str | None = None,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[WorkflowRunEvent]:
# get all the tasks for the workflow run
tasks = await app.DATABASE.get_tasks(
page,
page_size,
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
workflow_run_events: list[WorkflowRunEvent] = []
for task in tasks:
block_type = BlockType.TASK
if task.task_type == TaskType.general:
if not task.navigation_goal and task.data_extraction_goal:
block_type = BlockType.EXTRACTION
elif task.navigation_goal and not task.data_extraction_goal:
block_type = BlockType.NAVIGATION
elif task.task_type == TaskType.validation:
block_type = BlockType.VALIDATION
elif task.task_type == TaskType.action:
block_type = BlockType.ACTION
event = WorkflowRunEvent(
type=WorkflowRunEventType.block,
block=WorkflowRunBlock(
workflow_run_id=workflow_run_id,
block_type=block_type,
label=task.title,
title=task.title,
url=task.url,
status=task.status,
navigation_goal=task.navigation_goal,
data_extraction_goal=task.data_extraction_goal,
data_schema=task.extracted_information_schema,
terminate_criterion=task.terminate_criterion,
complete_criterion=task.complete_criterion,
created_at=task.created_at,
modified_at=task.modified_at,
),
created_at=task.created_at,
modified_at=task.modified_at,
)
workflow_run_events.append(event)
# get all the actions for all the tasks
actions = await app.DATABASE.get_tasks_actions(
[task.task_id for task in tasks], organization_id=current_org.organization_id
)
for action in actions:
workflow_run_events.append(
WorkflowRunEvent(
type=WorkflowRunEventType.action,
action=action,
created_at=action.created_at or datetime.datetime.utcnow(),
modified_at=action.modified_at or datetime.datetime.utcnow(),
)
)
# get all the thoughts for the cruise
if observer_cruise_id:
thoughts = await app.DATABASE.get_observer_cruise_thoughts(
observer_cruise_id, organization_id=current_org.organization_id
)
for thought in thoughts:
workflow_run_events.append(
WorkflowRunEvent(
type=WorkflowRunEventType.thought,
thought=thought,
created_at=thought.created_at,
modified_at=thought.modified_at,
)
)
workflow_run_events.sort(key=lambda x: x.created_at)
return workflow_run_events


@base_router.get(
"/workflows/runs/{workflow_run_id}",
response_model=WorkflowRunStatusResponse,
Expand Down
45 changes: 45 additions & 0 deletions skyvern/forge/sdk/schemas/workflow_runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from datetime import datetime
from enum import StrEnum
from typing import Any

from pydantic import BaseModel

from skyvern.forge.sdk.schemas.observers import ObserverThought
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.webeye.actions.actions import Action


class WorkflowRunBlock(BaseModel):
workflow_run_block_id: str = "placeholder"
workflow_run_id: str
parent_workflow_run_block_id: str | None = None
block_type: BlockType
label: str | None = None
title: str | None = None
status: str | None = None
output: dict | list | str | None = None
continue_on_failure: bool = False
task_id: str | None = None
url: str | None = None
navigation_goal: str | None = None
data_extraction_goal: str | None = None
data_schema: dict[str, Any] | list | str | None = None
terminate_criterion: str | None = None
complete_criterion: str | None = None
created_at: datetime
modified_at: datetime


class WorkflowRunEventType(StrEnum):
action = "action"
thought = "thought"
block = "block"


class WorkflowRunEvent(BaseModel):
type: WorkflowRunEventType
action: Action | None = None
thought: ObserverThought | None = None
block: WorkflowRunBlock | None = None
created_at: datetime
modified_at: datetime
Loading
Loading