diff --git a/browsergym/visualwebarena/requirements.txt b/browsergym/visualwebarena/requirements.txt index 8317960f..17f7e56a 100644 --- a/browsergym/visualwebarena/requirements.txt +++ b/browsergym/visualwebarena/requirements.txt @@ -2,3 +2,4 @@ browsergym-core==0.11.3 browsergym-webarena libvisualwebarena==0.0.14 requests +torch diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py index 6032410b..b6b7f91e 100644 --- a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py @@ -4,10 +4,11 @@ import pathlib import tempfile import urllib.parse -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import playwright.sync_api import requests +import torch from browsergym.core.task import AbstractBrowserTask @@ -101,6 +102,7 @@ def __init__( task_id: Optional[int] = None, intent_template_id: Optional[int] = None, with_na_hint: bool = False, + eval_captioning_model_device: Literal["cpu", "cuda"] = "cpu", ) -> None: super().__init__(seed) @@ -112,6 +114,7 @@ def __init__( self.webarena_instance = VisualWebArenaInstance() self.config_file: str = None self.with_na_hint = with_na_hint + self.eval_captioning_model_device = eval_captioning_model_device # one and only one of task id and template id must be provided if (task_id is None) == (intent_template_id is None): @@ -161,6 +164,7 @@ def __init__( def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]: # import webarena on instanciation from visualwebarena.evaluation_harness.evaluators import evaluator_router + from visualwebarena.evaluation_harness.image_utils import get_captioning_fn # pick a task at random self.config = self.random.choice(self.task_configs) @@ -172,7 +176,16 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]: self.config_file = f.name # build the evaluator - self.evaluator = evaluator_router(self.config_file) + captioning_fn = get_captioning_fn( + device=self.eval_captioning_model_device, + dtype=( + torch.float16 + if (self.eval_captioning_model_device == "cuda" and torch.cuda.is_available()) + else torch.float32 + ), + model_name="Salesforce/blip2-flan-t5-xl", + ) + self.evaluator = evaluator_router(self.config_file, captioning_fn=captioning_fn) # reset instance if needed (classifieds domain only) if self.config.get("require_reset", False):