Skip to content

Commit

Permalink
prepare_backend update for weblinx
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed Nov 1, 2024
1 parent 00e0c6b commit 9487948
Showing 1 changed file with 10 additions and 51 deletions.
61 changes: 10 additions & 51 deletions browsergym/experiments/src/browsergym/experiments/benchmark/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,60 +149,19 @@ def prepare_backends(self):
import browsergym.assistantbench

case "weblinx":
import json
import zipfile

# download weblinx resources from huggingface hub
import huggingface_hub

cache_dir = os.getenv("BROWSERGYM_WEBLINX_CACHE_DIR", "./bg_wl_data")
cache_dir = Path(cache_dir).expanduser()
base_demo_dir = cache_dir / "demonstrations"
base_zip_dir = cache_dir / "demonstrations_zip"

metadata_path = cache_dir / "metadata.json"

if not metadata_path.exists():
logger.info(f"Downloading metadata.json")
huggingface_hub.snapshot_download(
repo_id="McGill-NLP/weblinx-browsergym",
repo_type="dataset",
local_dir=cache_dir,
allow_patterns=["metadata.json"],
)

with open(metadata_path, "r") as f:
metadata = json.load(f)

for split, meta_split_dict in metadata.items():
for demo_id, steps in meta_split_dict.items():
for step_num, task_dict in steps.items():
if task_dict["is_task"] is False:
pass

# if the base_dir / project_id does not exist, download the dataset
if not base_demo_dir.joinpath(demo_id).exists():
# first, if zip file does not exist, download the zip file
if not base_zip_dir.joinpath(f"{demo_id}.zip").exists():
logger.info(f"Downloading demonstrations_zip/{demo_id}.zip")
huggingface_hub.snapshot_download(
repo_id="McGill-NLP/weblinx-browsergym",
repo_type="dataset",
local_dir=cache_dir,
allow_patterns=[f"demonstrations_zip/{demo_id}.zip"],
)

# then, unzip the file
with zipfile.ZipFile(
base_zip_dir.joinpath(f"{demo_id}.zip"), "r"
) as zip_ref:
zip_ref.extractall(base_demo_dir.joinpath(demo_id))

# register environments
import weblinx_browsergym

weblinx_browsergym.register_weblinx_tasks(split="train")
weblinx_browsergym.register_weblinx_tasks(split="valid")
# pre-download weblinx files
cache_dir = os.environ.get("BROWSERGYM_WEBLINX_CACHE_DIR", None)

assert (
cache_dir
), f"Environment variable BROWSERGYM_WEBLINX_CACHE_DIR is missing or empty, required to prepare the weblinx backend."

tasks = weblinx_browsergym.list_tasks(split="test_iid", cache_dir=cache_dir)
demo_ids = weblinx_browsergym.get_unique_demo_ids(tasks)
weblinx_browsergym.download_and_unzip_demos(demo_ids)

case _:
raise ValueError(f"Unknown benchmark backend {repr(backend)}")
Expand Down

0 comments on commit 9487948

Please sign in to comment.