-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
159 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
browsergym/experiments/src/browsergym/experiments/benchmark/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .base import DEFAULT_BENCHMARKS, Benchmark, HighLevelActionSetArgs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
24 changes: 24 additions & 0 deletions
24
browsergym/experiments/src/browsergym/experiments/benchmark/metadata/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import io | ||
import pkgutil | ||
|
||
import pandas as pd | ||
|
||
|
||
def task_metadata(benchmark_name: str): | ||
return task_metadata_from_csv( | ||
io.StringIO(pkgutil.get_data(__name__, f"{benchmark_name}.csv").decode("utf-8")) | ||
) | ||
|
||
|
||
def task_metadata_from_csv(filepath): | ||
return pd.read_csv(filepath).fillna("") | ||
|
||
|
||
def task_list_from_metadata(metadata: pd.DataFrame, filter: dict[str, str] = {}): | ||
df = metadata | ||
# filter the desired columns (AND filter) | ||
for col_name, regex in filter.items(): | ||
col_filter = df[col_name].astype(str).str.contains(regex, regex=True) | ||
df = df[col_filter] | ||
# return only the task names | ||
return list(df["task_name"]) |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
95 changes: 95 additions & 0 deletions
95
browsergym/experiments/src/browsergym/experiments/benchmark/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import io | ||
import pkgutil | ||
from typing import Literal | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from browsergym.experiments.loop import SEED_MAX, EnvArgs | ||
|
||
|
||
def make_env_args_list_from_workarena_curriculum( | ||
level: Literal["l1", "l2", "l3"], | ||
task_category_filter: str, | ||
meta_seed: int, | ||
max_steps: int, | ||
curriculum_type: Literal["human", "agent"], | ||
seeds_l1: int = 10, | ||
): | ||
""" | ||
Returns a WorkArena predefined task curriculum (e.g., task and seed combination). | ||
""" | ||
assert level in ("l1", "l2", "l3") | ||
assert curriculum_type in ("human", "agent") | ||
|
||
env_args_list = [] | ||
|
||
# dynamic import | ||
from browsergym.workarena import get_all_tasks_agents | ||
|
||
all_task_tuples = get_all_tasks_agents( | ||
filter=f"{level}.{task_category_filter}" if task_category_filter else level, | ||
meta_seed=meta_seed, | ||
is_agent_curriculum=(curriculum_type == "agent"), | ||
n_seed_l1=seeds_l1, | ||
) | ||
|
||
for task, seed in all_task_tuples: | ||
task_name = task.get_task_id() | ||
env_args_list.append(EnvArgs(task_name=task_name, task_seed=seed, max_steps=max_steps)) | ||
|
||
return env_args_list | ||
|
||
|
||
def make_env_args_list_from_repeat_tasks( | ||
task_list: list[str], max_steps: int, n_repeats: int, seeds_rng: np.random.RandomState | ||
): | ||
""" | ||
Generates a list of `len(task_list)` time `n_repeats` environments arguments, using randomly generated seeds. | ||
""" | ||
env_args_list = [] | ||
for task in task_list: | ||
for seed in seeds_rng.randint(low=0, high=SEED_MAX, size=n_repeats): | ||
env_args_list.append( | ||
EnvArgs( | ||
task_name=task, | ||
task_seed=int(seed), | ||
max_steps=max_steps, | ||
headless=True, | ||
record_video=False, | ||
wait_for_user_message=False, | ||
viewport=None, | ||
slow_mo=None, | ||
storage_state=None, | ||
task_kwargs=None, | ||
) | ||
) | ||
|
||
return env_args_list | ||
|
||
|
||
def make_env_args_list_from_fixed_seeds( | ||
task_list: list[str], max_steps: int, fixed_seeds: list[int] | ||
): | ||
""" | ||
Generates a list of `len(task_list)` time `n_repeats` environments arguments, using randomly generated seeds. | ||
""" | ||
env_args_list = [] | ||
for task in task_list: | ||
for seed in fixed_seeds: | ||
env_args_list.append( | ||
EnvArgs( | ||
task_name=task, | ||
task_seed=int(seed), | ||
max_steps=max_steps, | ||
headless=True, | ||
record_video=False, | ||
wait_for_user_message=False, | ||
viewport=None, | ||
slow_mo=None, | ||
storage_state=None, | ||
task_kwargs=None, | ||
) | ||
) | ||
|
||
return env_args_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.