Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

show sandbox container names in sample header #896

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/inspect_ai/_display/core/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH
from inspect_ai._util.path import cwd_relative_path
from inspect_ai._util.registry import registry_unqualified_name

from .display import TaskProfile
from .rich import is_vscode_notebook, rich_theme
Expand Down Expand Up @@ -82,7 +83,7 @@ def task_title(profile: TaskProfile, show_model: bool) -> str:
eval_epochs = profile.eval_config.epochs or 1
epochs = f" x {profile.eval_config.epochs}" if eval_epochs > 1 else ""
samples = f"{profile.samples//eval_epochs:,}{epochs} sample{'s' if profile.samples > 1 else ''}"
title = f"{profile.name} ({samples})"
title = f"{registry_unqualified_name(profile.name)} ({samples})"
if show_model:
title = f"{title}: {profile.model}"
return title
Expand Down
168 changes: 147 additions & 21 deletions src/inspect_ai/_display/textual/widgets/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,29 @@
from rich.table import Table
from rich.text import Text
from textual.app import ComposeResult
from textual.containers import Horizontal, HorizontalGroup, VerticalGroup
from textual.containers import (
Horizontal,
HorizontalGroup,
Vertical,
VerticalGroup,
)
from textual.widget import Widget
from textual.widgets import Button, LoadingIndicator, OptionList, Static
from textual.widgets import (
Button,
Collapsible,
LoadingIndicator,
OptionList,
Static,
)
from textual.widgets.option_list import Option, Separator

from inspect_ai._util.registry import registry_unqualified_name
from inspect_ai.log._samples import ActiveSample
from inspect_ai.util._sandbox import (
SandboxConnection,
SandboxConnectionContainer,
SandboxConnectionShell,
)

from ...core.progress import progress_time
from .clock import Clock
Expand Down Expand Up @@ -52,12 +68,24 @@ def set_samples(self, samples: list[ActiveSample]) -> None:
self.query_one(SamplesList).set_samples(samples)

async def set_highlighted_sample(self, highlighted: int | None) -> None:
    """Show and sync the sample detail widgets for the highlighted list entry.

    Args:
        highlighted: Index of the highlighted option in the samples list,
            or None when nothing is highlighted.

    If no sample resolves from the highlight (or nothing is highlighted),
    the detail widgets are hidden instead.
    """
    sample_info = self.query_one(SampleInfo)
    transcript_view = self.query_one(TranscriptView)
    sample_toolbar = self.query_one(SampleToolbar)
    if highlighted is not None:
        sample = self.query_one(SamplesList).sample_for_highlighted(highlighted)
        if sample is not None:
            # reveal the detail ui before syncing it to the sample
            sample_info.display = True
            transcript_view.display = True
            sample_toolbar.display = True
            await sample_info.sync_sample(sample)
            await transcript_view.sync_sample(sample)
            await sample_toolbar.sync_sample(sample)
            return

    # otherwise hide ui
    sample_info.display = False
    transcript_view.display = False
    sample_toolbar.display = False


class SamplesList(OptionList):
Expand Down Expand Up @@ -151,35 +179,133 @@ class SampleInfo(Horizontal):
DEFAULT_CSS = """
SampleInfo {
color: $text-muted;
layout: grid;
grid-size: 2 1;
grid-columns: 1fr auto;
width: 100%;
}
#sample-info-model {
text-align: right;
SampleInfo Collapsible {
padding: 0;
border-top: none;
}
SampleInfo Collapsible CollapsibleTitle {
padding: 0;
color: $secondary;
&:hover {
background: $block-hover-background;
color: $primary;
}
&:focus {
background: $block-hover-background;
color: $primary;
}
}
SampleInfo Collapsible Contents {
padding: 1 0 1 2;
overflow-y: hidden;
overflow-x: auto;
}
SampleInfo Static {
width: 1fr;
background: $surface;
color: $secondary;
}
"""

def __init__(self) -> None:
    """Create an empty sample info header (no sample bound yet)."""
    super().__init__()
    # sync_sample() populates both of these; _show_sandboxes tracks
    # whether the last compose() included the sandbox Collapsible
    self._show_sandboxes = False
    self._sample: ActiveSample | None = None

def compose(self) -> ComposeResult:
    """Yield a collapsible sandbox view when the current sample has
    sandboxes, otherwise a plain Static for the title."""
    has_sandboxes = self._sample is not None and len(self._sample.sandboxes) > 0
    if not has_sandboxes:
        yield Static()
        return
    # title is filled in later by sync_sample()
    with Collapsible(title=""):
        yield SandboxesView()

async def sync_sample(self, sample: ActiveSample | None) -> None:
    """Update the info header (and sandbox view) to reflect *sample*.

    Recomposes the widget when the presence of sandboxes changes, since
    compose() yields a Collapsible only when sandboxes exist.

    Args:
        sample: Active sample to display, or None to hide the header.
    """
    # bail if we've already processed this sample
    if self._sample == sample:
        return

    # set sample
    self._sample = sample

    # compute whether we should show connection and recompose as required
    show_sandboxes = sample is not None and len(sample.sandboxes) > 0
    if show_sandboxes != self._show_sandboxes:
        await self.recompose()
        self._show_sandboxes = show_sandboxes

    if sample is not None:
        self.display = True
        title = f"{registry_unqualified_name(sample.task)} (id: {sample.sample.id}, epoch {sample.epoch}): {sample.model}"
        if show_sandboxes:
            # title goes on the collapsible; its contents list the sandboxes
            self.query_one(Collapsible).title = title
            sandboxes = self.query_one(SandboxesView)
            await sandboxes.sync_sandboxes(sample.sandboxes)
        else:
            self.query_one(Static).update(title)
    else:
        self.display = False


class SandboxesView(Vertical):
    """Vertical list of a sample's sandbox connection targets.

    Shows a caption (e.g. "sandbox containers:"), one line per sandbox
    target, and a copy hint footer. Populated via sync_sandboxes().
    """

    DEFAULT_CSS = """
    SandboxesView {
        padding: 0 0 1 0;
        background: transparent;
        height: auto;
    }
    SandboxesView Static {
        background: transparent;
    }
    """

    def __init__(self) -> None:
        """Create an empty view; call sync_sandboxes() to populate it."""
        super().__init__()

    def compose(self) -> ComposeResult:
        # caption text is set later by sync_sandboxes()
        yield Static(id="sandboxes-caption", markup=True)
        # container that sync_sandboxes() mounts one Static per sandbox into
        yield Vertical(id="sandboxes")
        yield Static(
            "[italic]Hold down Alt (or Option) to select text for copying[/italic]",
            id="sandboxes-footer",
            markup=True,
        )

    async def sync_sandboxes(self, sandboxes: dict[str, SandboxConnection]) -> None:
        """Update the caption and target list to reflect *sandboxes*.

        Args:
            sandboxes: Mapping of sandbox name to its connection info.
                NOTE(review): raises IndexError if empty — SampleInfo only
                calls this when the sample has at least one sandbox.
        """

        def sandbox_connection_type() -> str:
            # caption noun chosen from the first connection's concrete type;
            # assumes all connections share one type — TODO confirm
            connection = list(sandboxes.values())[0]
            if isinstance(connection, SandboxConnectionShell):
                return "directories"
            elif isinstance(connection, SandboxConnectionContainer):
                return "containers"
            else:
                return "hosts"

        def sandbox_connection_target(sandbox: SandboxConnection) -> str:
            # user-visible "address" of the sandbox: working dir for shells,
            # container name for containers, destination otherwise
            if isinstance(sandbox, SandboxConnectionShell):
                target = sandbox.working_dir
            elif isinstance(sandbox, SandboxConnectionContainer):
                target = sandbox.container
            else:
                target = sandbox.destination
            return target.strip()

        caption = cast(Static, self.query_one("#sandboxes-caption"))
        caption.update(f"[bold]sandbox {sandbox_connection_type()}:[/bold]")

        sandboxes_widget = self.query_one("#sandboxes")
        # extra bottom margin only when listing more than one sandbox
        sandboxes_widget.styles.margin = (
            (0, 0, 1, 0) if len(sandboxes) > 1 else (0, 0, 0, 0)
        )
        # replace any previously mounted targets with the current set
        await sandboxes_widget.remove_children()
        await sandboxes_widget.mount_all(
            [
                Static(sandbox_connection_target(sandbox))
                for sandbox in sandboxes.values()
            ]
        )


class SampleToolbar(Horizontal):
DEFAULT_CSS = """
SampleToolbar Button {
Expand Down
16 changes: 7 additions & 9 deletions src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from inspect_ai.log._condense import condense_sample
from inspect_ai.log._file import eval_log_json
from inspect_ai.log._log import EvalSampleLimit, EvalSampleReductions, eval_error
from inspect_ai.log._samples import ActiveSample, active_sample
from inspect_ai.log._samples import active_sample
from inspect_ai.log._transcript import (
ErrorEvent,
SampleInitEvent,
Expand Down Expand Up @@ -432,14 +432,12 @@ def handle_error(ex: BaseException) -> EvalError:
semaphore_cm,
sandboxenv_cm,
active_sample(
ActiveSample(
task_name,
str(state.model),
sample,
state.epoch,
fails_on_error,
sample_transcript,
)
task=task_name,
model=str(state.model),
sample=sample,
epoch=state.epoch,
fails_on_error=fails_on_error,
transcript=sample_transcript,
) as active,
):
error: EvalError | None = None
Expand Down
3 changes: 2 additions & 1 deletion src/inspect_ai/_eval/task/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from inspect_ai._eval.task.task import Task
from inspect_ai._eval.task.util import task_run_dir
from inspect_ai._util.file import file, filesystem
from inspect_ai._util.registry import registry_unqualified_name
from inspect_ai._util.url import data_uri_to_base64, is_data_uri
from inspect_ai.dataset import Sample
from inspect_ai.util._sandbox.context import (
Expand Down Expand Up @@ -51,7 +52,7 @@ async def sandboxenv_context(
# initialize sandbox environment,
environments = await init_sandbox_environments_sample(
type=sandbox.type,
task_name=task_name,
task_name=registry_unqualified_name(task_name),
config=sandbox.config,
files=files,
setup=setup,
Expand Down
32 changes: 27 additions & 5 deletions src/inspect_ai/log/_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from shortuuid import uuid

from inspect_ai.dataset._dataset import Sample
from inspect_ai.util._sandbox import SandboxConnection
from inspect_ai.util._sandbox.context import sandbox_connections

from ._transcript import Transcript

Expand All @@ -19,6 +21,7 @@ def __init__(
epoch: int,
fails_on_error: bool,
transcript: Transcript,
sandboxes: dict[str, SandboxConnection],
) -> None:
self.id = uuid()
self.started = datetime.now().timestamp()
Expand All @@ -29,6 +32,7 @@ def __init__(
self.epoch = epoch
self.fails_on_error = fails_on_error
self.transcript = transcript
self.sandboxes = sandboxes
self._sample_task = asyncio.current_task()
self._interrupt_action: Literal["score", "error"] | None = None

Expand All @@ -54,13 +58,31 @@ def init_active_samples() -> None:


@contextlib.asynccontextmanager
async def active_sample(
    task: str,
    model: str,
    sample: Sample,
    epoch: int,
    fails_on_error: bool,
    transcript: Transcript,
) -> AsyncGenerator[ActiveSample, None]:
    """Track a sample as active for the duration of the context.

    Creates an ActiveSample (capturing its current sandbox connections),
    registers it in the module-level active-samples list, and on exit
    records its completion time and unregisters it.

    Args:
        task: Task name.
        model: Model name.
        sample: The dataset sample being run.
        epoch: Epoch number for this run of the sample.
        fails_on_error: Whether an error in this sample fails the eval.
        transcript: Transcript to associate with the sample.

    Yields:
        The registered ActiveSample.
    """
    # create the sample
    active = ActiveSample(
        task=task,
        model=model,
        sample=sample,
        epoch=epoch,
        sandboxes=await sandbox_connections(),
        fails_on_error=fails_on_error,
        transcript=transcript,
    )

    _active_samples.append(active)
    try:
        yield active
    finally:
        # always stamp completion and unregister, even on error/cancel
        active.completed = datetime.now().timestamp()
        _active_samples.remove(active)


def active_samples() -> list[ActiveSample]:
Expand Down
8 changes: 8 additions & 0 deletions src/inspect_ai/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from ._resource import resource
from ._sandbox import (
OutputLimitExceededError,
SandboxConnection,
SandboxConnectionContainer,
SandboxConnectionShell,
SandboxConnectionSSH,
SandboxEnvironment,
SandboxEnvironmentLimits,
SandboxEnvironments,
Expand Down Expand Up @@ -32,6 +36,10 @@
"SandboxEnvironments",
"SandboxEnvironmentSpec",
"SandboxEnvironmentType",
"SandboxConnection",
"SandboxConnectionContainer",
"SandboxConnectionShell",
"SandboxConnectionSSH",
"sandboxenv",
"sandbox",
"sandbox_with",
Expand Down
8 changes: 8 additions & 0 deletions src/inspect_ai/util/_sandbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from .context import sandbox, sandbox_with
from .docker.docker import DockerSandboxEnvironment # noqa: F401
from .environment import (
SandboxConnection,
SandboxConnectionContainer,
SandboxConnectionShell,
SandboxConnectionSSH,
SandboxEnvironment,
SandboxEnvironments,
SandboxEnvironmentSpec,
Expand All @@ -19,6 +23,10 @@
"SandboxEnvironments",
"SandboxEnvironmentSpec",
"SandboxEnvironmentType",
"SandboxConnection",
"SandboxConnectionContainer",
"SandboxConnectionShell",
"SandboxConnectionSSH",
"sandboxenv",
"sandbox",
"sandbox_with",
Expand Down
Loading
Loading