Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions src/gradient/resources/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,57 @@ def retrieve(
cast_to=AgentRetrieveResponse,
)

def wait_for_deployment(
    self,
    uuid: str,
    *,
    timeout: float = 120.0,
    interval: float = 2.0,
    raise_on_failed: bool = True,
) -> AgentRetrieveResponse:
    """Poll the agent deployment until it reaches `STATUS_RUNNING` or a terminal failed state.

    The agent is always retrieved at least once, even when `timeout` is zero.
    `timeout` bounds the waiting between polls; the time spent inside a single
    `retrieve` call is not limited by it.

    Args:
      uuid: Agent UUID to poll
      timeout: Maximum seconds to wait before raising TimeoutError
      interval: Seconds between polls (shortened for the last poll so the
          call does not sleep past `timeout`)
      raise_on_failed: If True, raise RuntimeError when deployment enters a failed state

    Returns:
      The `AgentRetrieveResponse` from the last poll: either the one whose
      deployment status is `STATUS_RUNNING`, or, when `raise_on_failed` is
      False, the one whose status is `STATUS_FAILED`,
      `STATUS_UNDEPLOYMENT_FAILED` or `STATUS_DELETED`.

    Raises:
      ValueError: if `uuid` is empty
      TimeoutError: if the timeout is exceeded
      RuntimeError: if deployment enters a failed terminal state and `raise_on_failed` is True
    """
    import time

    if not uuid:
        raise ValueError("Expected a non-empty value for `uuid`")

    # Use the monotonic clock so wall-clock adjustments (NTP, DST, manual
    # changes) can neither cut the wait short nor extend it.
    deadline = time.monotonic() + timeout
    failed_states = {"STATUS_FAILED", "STATUS_UNDEPLOYMENT_FAILED", "STATUS_DELETED"}

    while True:
        resp = self.retrieve(uuid)
        agent = resp.agent
        # A missing agent, deployment or status is treated as "not ready yet".
        status = None
        if agent and agent.deployment and agent.deployment.status:
            status = agent.deployment.status

        if status == "STATUS_RUNNING":
            return resp

        if status in failed_states:
            if raise_on_failed:
                raise RuntimeError(f"Agent {uuid} deployment entered failed state: {status}")
            return resp

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"Timed out waiting for agent {uuid} to be running")

        # Never sleep past the deadline: cap the pause at the time left so one
        # final poll happens right at the deadline instead of up to a full
        # `interval` after it.
        time.sleep(min(interval, remaining))

def update(
self,
path_uuid: str,
Expand Down Expand Up @@ -792,6 +843,58 @@ async def retrieve(
cast_to=AgentRetrieveResponse,
)

async def wait_for_deployment(
    self,
    uuid: str,
    *,
    timeout: float = 120.0,
    interval: float = 2.0,
    raise_on_failed: bool = True,
) -> AgentRetrieveResponse:
    """Async poll until agent deployment reaches `STATUS_RUNNING` or a terminal failed state.

    The agent is always retrieved at least once, even when `timeout` is zero.
    `timeout` bounds the waiting between polls; the time spent awaiting a
    single `retrieve` call is not limited by it.

    Args:
      uuid: Agent UUID to poll
      timeout: Maximum seconds to wait before raising TimeoutError
      interval: Seconds between polls (shortened for the last poll so the
          call does not sleep past `timeout`)
      raise_on_failed: If True, raise RuntimeError when deployment enters a failed state

    Returns:
      The `AgentRetrieveResponse` from the last poll: either the one whose
      deployment status is `STATUS_RUNNING`, or, when `raise_on_failed` is
      False, the one whose status is `STATUS_FAILED`,
      `STATUS_UNDEPLOYMENT_FAILED` or `STATUS_DELETED`.

    Raises:
      ValueError: if `uuid` is empty
      TimeoutError: if the timeout is exceeded
      RuntimeError: if deployment enters a failed terminal state and `raise_on_failed` is True
    """
    import asyncio
    import time

    if not uuid:
        raise ValueError("Expected a non-empty value for `uuid`")

    # Use the monotonic clock so wall-clock adjustments (NTP, DST, manual
    # changes) can neither cut the wait short nor extend it.
    deadline = time.monotonic() + timeout
    failed_states = {"STATUS_FAILED", "STATUS_UNDEPLOYMENT_FAILED", "STATUS_DELETED"}

    while True:
        resp = await self.retrieve(uuid)
        agent = resp.agent
        # A missing agent, deployment or status is treated as "not ready yet".
        status = None
        if agent and agent.deployment and agent.deployment.status:
            status = agent.deployment.status

        if status == "STATUS_RUNNING":
            return resp

        if status in failed_states:
            if raise_on_failed:
                raise RuntimeError(f"Agent {uuid} deployment entered failed state: {status}")
            return resp

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"Timed out waiting for agent {uuid} to be running")

        # Never sleep past the deadline: cap the pause at the time left so one
        # final poll happens right at the deadline instead of up to a full
        # `interval` after it. `asyncio.sleep` yields to the event loop, so
        # other tasks keep running while this one waits.
        await asyncio.sleep(min(interval, remaining))

async def update(
self,
path_uuid: str,
Expand Down