Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ select = ["I"]
case-sensitive = false

[tool.uv]
required-version = ">=0.6.15"
required-version = ">=0.6.13"
dev-dependencies = [
"black>=25.1.0",
"ipykernel>=6.29.5",
Expand Down
11 changes: 1 addition & 10 deletions src/art/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import Body, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse


from . import dev
from .errors import ARTError
from .local import LocalBackend
Expand All @@ -34,16 +35,6 @@ def is_port_available(port: int) -> bool:
)
return

# Reset the custom __new__ and __init__ methods for TrajectoryGroup
def __new__(cls, *args: Any, **kwargs: Any) -> TrajectoryGroup:
return pydantic.BaseModel.__new__(cls)

def __init__(self, *args: Any, **kwargs: Any) -> None:
return pydantic.BaseModel.__init__(self, *args, **kwargs)

TrajectoryGroup.__new__ = __new__ # type: ignore
TrajectoryGroup.__init__ = __init__

backend = LocalBackend()
app = FastAPI()

Expand Down
82 changes: 35 additions & 47 deletions src/art/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ class TrajectoryGroup(pydantic.BaseModel):

def __init__(
self,
trajectories: (
Iterable[Trajectory | BaseException] | Iterable[Awaitable[Trajectory]]
),
trajectories: Iterable[Trajectory | BaseException],
*,
exceptions: list[BaseException] = [],
) -> None:
Expand Down Expand Up @@ -176,53 +174,43 @@ def __new__(

def __new__(
cls,
trajectories: (
Iterable[Trajectory | BaseException] | Iterable[Awaitable[Trajectory]]
),
trajectories: Iterable,
*,
exceptions: list[BaseException] = [],
) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]":
):
ts = list(trajectories)
if any(hasattr(t, "__await__") for t in ts):

async def _(exceptions: list[BaseException]):
from .gather import get_gather_context, record_metrics

context = get_gather_context()
trajectories = []
for future in asyncio.as_completed(
cast(list[Awaitable[Trajectory]], ts)
):
try:
trajectory = await future
trajectories.append(trajectory)
record_metrics(context, trajectory)
context.update_pbar(n=1)
except BaseException as e:
exceptions.append(e)
context.metric_sums["exceptions"] += 1
context.update_pbar(n=0)
if context.too_many_exceptions():
raise
return TrajectoryGroup(
trajectories=trajectories,
exceptions=exceptions,
)
if not ts or isinstance(ts[0], (Trajectory, BaseException)):
return super().__new__(cls)

async def _(exceptions: list[BaseException]):
from .gather import get_gather_context, record_metrics

context = get_gather_context()
trajectories = []
for future in asyncio.as_completed(cast(list[Awaitable[Trajectory]], ts)):
try:
trajectory = await future
trajectories.append(trajectory)
record_metrics(context, trajectory)
context.update_pbar(n=1)
except BaseException as e:
exceptions.append(e)
context.metric_sums["exceptions"] += 1
context.update_pbar(n=0)
if context.too_many_exceptions():
raise
return TrajectoryGroup(
trajectories=trajectories,
exceptions=exceptions,
)

class CoroutineWithMetadata:
def __init__(self, coro, num_trajectories):
self.coro = coro
self._num_trajectories = num_trajectories
class CoroutineWithMetadata:
def __init__(self, coro, num_trajectories):
self.coro = coro
self._num_trajectories = num_trajectories

def __await__(self):
return self.coro.__await__()
def __await__(self):
return self.coro.__await__()

coro = _(exceptions.copy())
return CoroutineWithMetadata(coro, len(ts))
else:
group = super().__new__(cls)
group.__init__(
trajectories=cast(list[Trajectory | BaseException], ts),
exceptions=exceptions,
)
return group
coro = _(exceptions.copy())
return CoroutineWithMetadata(coro, len(ts))