Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark.prepare_backend() #204

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import fnmatch
import logging
import typing
from dataclasses import dataclass, field
from typing import Literal, Optional

Expand Down Expand Up @@ -49,12 +50,16 @@ def make_action_set(self):
)


BenchmarkBackend = Literal["miniwob", "webarena", "visualwebarena", "workarena", "assistantbench"]


@dataclass
class Benchmark(DataClassJsonMixin):
name: str
high_level_action_set_args: HighLevelActionSetArgs
is_multi_tab: bool
env_args_list: list[EnvArgs]
backends: list[BenchmarkBackend]
task_metadata: Optional[pd.DataFrame] = field(
default_factory=lambda: None,
metadata=config(
Expand All @@ -73,6 +78,57 @@ def __post_init__(self):
# make sure all tasks in env_args are in the metadata
metadata_tasks = list(self.task_metadata["task_name"])
assert all([env_args.task_name in metadata_tasks for env_args in self.env_args_list])
# check backend values
assert all([backend in typing.get_args(BenchmarkBackend) for backend in self.backends])

def prepare_backends(self):
for backend in self.backends:
match backend:
case "miniwob":
# register environments
import browsergym.miniwob

# check setup
browsergym.miniwob.environment_variables_precheck()

case "webarena":
# register environments
import browsergym.webarena

# full reset the instance (requires environment variables properly set up)
from browsergym.webarena.instance import WebArenaInstance

default_instance = WebArenaInstance()
default_instance.full_reset()

case "visualwebarena":
# register environments
import browsergym.visualwebarena

# full reset the instance (requires environment variables properly set up)
from browsergym.visualwebarena.instance import (
VisualWebArenaInstance,
)

default_instance = VisualWebArenaInstance()
default_instance.full_reset()

case "workarena":
# register environments
import browsergym.workarena

# check server status
from browsergym.workarena.instance import SNowInstance

default_instance = SNowInstance()
default_instance.check_status()

case "assistantbench":
# register environments
import browsergym.assistantbench

case _:
raise ValueError(f"Unknown benchmark backend {repr(backend)}")

def subset_from_split(self, split: Literal["train", "valid", "test"]):
split_column = "browsergym_split"
Expand Down Expand Up @@ -106,6 +162,7 @@ def subset_from_regexp(self, column, regexp):
name=f"{self.name}[{column}=/{regexp}/]",
high_level_action_set_args=self.high_level_action_set_args,
is_multi_tab=self.is_multi_tab,
backends=self.backends,
env_args_list=[
env_args
for env_args in self.env_args_list
Expand Down Expand Up @@ -191,6 +248,7 @@ def subset_from_regexp(self, column, regexp):
name="miniwob",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob_all"],
is_multi_tab=False,
backends=["miniwob"],
env_args_list=make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(metadata=task_metadata("miniwob")),
max_steps=10,
Expand All @@ -203,6 +261,7 @@ def subset_from_regexp(self, column, regexp):
name="miniwob_tiny_test",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob_all"],
is_multi_tab=False,
backends=["miniwob"],
env_args_list=make_env_args_list_from_repeat_tasks(
task_list=["miniwob.click-dialog", "miniwob.click-checkboxes"],
max_steps=5,
Expand All @@ -215,6 +274,7 @@ def subset_from_regexp(self, column, regexp):
name="webarena",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"],
is_multi_tab=True,
backends=["webarena"],
env_args_list=make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(metadata=task_metadata("webarena")),
max_steps=15,
Expand All @@ -227,6 +287,7 @@ def subset_from_regexp(self, column, regexp):
name="visualwebarena",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["visualwebarena"],
is_multi_tab=True,
backends=["visualwebarena"],
env_args_list=make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(metadata=task_metadata("visualwebarena")),
max_steps=15,
Expand All @@ -239,6 +300,7 @@ def subset_from_regexp(self, column, regexp):
name="workarena_l1",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena"],
is_multi_tab=False,
backends=["workarena"],
env_args_list=make_env_args_list_from_workarena_curriculum(
level="l1",
task_category_filter=None,
Expand All @@ -253,6 +315,7 @@ def subset_from_regexp(self, column, regexp):
name="workarena_l2_agent_curriculum_eval",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena++"],
is_multi_tab=True,
backends=["workarena"],
env_args_list=make_env_args_list_from_workarena_curriculum(
level="l2",
task_category_filter=None,
Expand All @@ -266,6 +329,7 @@ def subset_from_regexp(self, column, regexp):
name="workarena_l3_agent_curriculum_eval",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena++"],
is_multi_tab=True,
backends=["workarena"],
env_args_list=make_env_args_list_from_workarena_curriculum(
level="l3",
task_category_filter=None,
Expand All @@ -279,6 +343,7 @@ def subset_from_regexp(self, column, regexp):
name="assistantbench",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["assistantbench"],
is_multi_tab=True,
backends=["assistantbench"],
env_args_list=make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
metadata=task_metadata("assistantbench"), filter={"browsergym_split": "valid|test"}
Expand Down
10 changes: 5 additions & 5 deletions browsergym/experiments/src/browsergym/experiments/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@
@dataclass
class EnvArgs(DataClassJsonMixin):
task_name: str
task_seed: int = None
max_steps: int = None
task_seed: Optional[int] = None
gasse marked this conversation as resolved.
Show resolved Hide resolved
max_steps: Optional[int] = None
headless: bool = True
record_video: bool = False
wait_for_user_message: bool = False
viewport: dict = None # use default value from BrowserGym
slow_mo: int = None # use default value from BrowserGym
viewport: Optional[dict] = None # use default value from BrowserGym
slow_mo: Optional[int] = None # use default value from BrowserGym
storage_state: Optional[str | Path | dict] = None
task_kwargs: dict = None # use default value from BrowserGym
task_kwargs: Optional[dict] = None # use default value from BrowserGym

def make_env(self, action_mapping, exp_dir):
extra_kwargs = {}
Expand Down
9 changes: 9 additions & 0 deletions browsergym/miniwob/src/browsergym/miniwob/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import os

from browsergym.core.registration import register_task

from . import all


def environment_variables_precheck():
assert os.environ.get(
"MINIWOB_URL", None
), "Environment variable MINIWOB_URL has not been setup."


ALL_MINIWOB_TASKS = [
all.AscendingNumbersTask,
all.BisectAngleTask,
Expand Down
1 change: 1 addition & 0 deletions browsergym/miniwob/src/browsergym/miniwob/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, Tuple

import playwright.sync_api

from browsergym.core.task import AbstractBrowserTask


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import playwright.sync_api
import logging
import os

import playwright.sync_api
import requests

logger = logging.getLogger(__name__)


ENV_VARS = ("SHOPPING", "REDDIT", "WIKIPEDIA", "HOMEPAGE", "CLASSIFIEDS", "CLASSIFIEDS_RESET_TOKEN")

Expand Down Expand Up @@ -30,12 +34,12 @@ def __init__(
# import webarena on instantiation
from visualwebarena.browser_env.env_config import (
ACCOUNTS,
CLASSIFIEDS,
CLASSIFIEDS_RESET_TOKEN,
HOMEPAGE,
REDDIT,
SHOPPING,
WIKIPEDIA,
HOMEPAGE,
CLASSIFIEDS,
CLASSIFIEDS_RESET_TOKEN,
)

self.urls = {
Expand All @@ -49,6 +53,29 @@ def __init__(

self.credentials = ACCOUNTS

def full_reset(self):
reset_url = os.environ.get("VWA_FULL_RESET", None)

assert (
reset_url
), f"Environment variable VWA_FULL_RESET is missing or empty, required for a full instance reset."

# Send the GET request to trigger the reset script
logger.info(f"VisualWebArena full instance reset in progress.")

# 5 minutes timeout (takes 2-3 minutes in practice)
# https://requests.readthedocs.io/en/stable/user/advanced/#timeouts
response = requests.get(reset_url, timeout=(3.05, 5 * 60))

# Print the response from the server
logger.info(f"Reset status code: {response.status_code}")
logger.info(f"Reset response: {response.text}")

if not response.status_code == 200:
raise Exception(
f"Full instance reset failed ({response.status_code}): {response.status_code}"
)

def check_status(self):
"""
Check the status of the instance. Raises an error if the instance is not ready to be used.
Expand Down
54 changes: 47 additions & 7 deletions browsergym/webarena/src/browsergym/webarena/instance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import playwright.sync_api
import logging
import os

import playwright.sync_api
import requests

logger = logging.getLogger(__name__)

ENV_VARS = ("SHOPPING", "SHOPPING_ADMIN", "REDDIT", "GITLAB", "WIKIPEDIA", "MAP", "HOMEPAGE")

Expand Down Expand Up @@ -29,13 +32,13 @@ def __init__(
# import webarena on instanciation
from webarena.browser_env.env_config import (
ACCOUNTS,
GITLAB,
HOMEPAGE,
MAP,
REDDIT,
SHOPPING,
SHOPPING_ADMIN,
GITLAB,
WIKIPEDIA,
MAP,
HOMEPAGE,
)

self.urls = {
Expand All @@ -50,21 +53,58 @@ def __init__(

self.credentials = ACCOUNTS

def full_reset(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

helper function instead of code duplicate?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean code duplicate? This code is new

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's 90% the same code on the side of visualwebarena

reset_url = os.environ.get("WA_FULL_RESET", None)

assert (
reset_url
), f"Environment variable WA_FULL_RESET is missing or empty, required for a full instance reset."

# Send the GET request to trigger the reset script
logger.info(f"WebArena full instance reset in progress.")

# 5 minutes timeout (takes 2-3 minutes in practice)
# https://requests.readthedocs.io/en/stable/user/advanced/#timeouts
response = requests.get(reset_url, timeout=(3.05, 5 * 60))

# Print the response from the server
logger.info(f"Reset status code: {response.status_code}")
logger.info(f"Reset response: {response.text}")

if not response.status_code == 200:
raise Exception(
f"Full instance reset failed ({response.status_code}): {response.status_code}"
)

# warm-start the instance (navigate to every domain)
retries_left = 3
while retries_left:
retries_left -= 1
try:
self._check_is_reachable(timeout=60) # 60 seconds, cold starting might be slow
break
except Exception as e:
if not retries_left:
raise
logger.info(
f"Instance unresponsive after reset, retrying ({retries_left} retries left)\n{e}"
)

def check_status(self):
"""
Check the status of the instance. Raises an error if the instance is not ready to be used.

"""
self._check_is_reachable()
self._check_is_reachable(timeout=10) # 10 seconds

def _check_is_reachable(self):
def _check_is_reachable(self, timeout: int):
"""
Test that every website is reachable.

"""
for site, url in self.urls.items():
try:
requests.get(url, timeout=5000) # 5 secs
requests.get(url, timeout=timeout)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
raise RuntimeError(
f'WebArena site "{site}" ({url}) is not reacheable. Please check the URL.'
Expand Down
Loading