Skip to content

Commit

Permalink
captioning_fn fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed Nov 4, 2024
1 parent 385ccb6 commit 1ed9a7a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
1 change: 1 addition & 0 deletions browsergym/visualwebarena/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ browsergym-core==0.11.3
browsergym-webarena
libvisualwebarena==0.0.14
requests
torch
17 changes: 15 additions & 2 deletions browsergym/visualwebarena/src/browsergym/visualwebarena/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pathlib
import tempfile
import urllib.parse
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple

import playwright.sync_api
import requests
import torch

from browsergym.core.task import AbstractBrowserTask

Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
task_id: Optional[int] = None,
intent_template_id: Optional[int] = None,
with_na_hint: bool = False,
eval_captioning_model_device: Literal["cpu", "cuda"] = "cpu",
) -> None:
super().__init__(seed)

Expand All @@ -112,6 +114,7 @@ def __init__(
self.webarena_instance = VisualWebArenaInstance()
self.config_file: str = None
self.with_na_hint = with_na_hint
self.eval_captioning_model_device = eval_captioning_model_device

# one and only one of task id and template id must be provided
if (task_id is None) == (intent_template_id is None):
Expand Down Expand Up @@ -161,6 +164,7 @@ def __init__(
def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
# import webarena on instanciation
from visualwebarena.evaluation_harness.evaluators import evaluator_router
from visualwebarena.evaluation_harness.image_utils import get_captioning_fn

# pick a task at random
self.config = self.random.choice(self.task_configs)
Expand All @@ -172,7 +176,16 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
self.config_file = f.name

# build the evaluator
self.evaluator = evaluator_router(self.config_file)
captioning_fn = get_captioning_fn(
device=self.eval_captioning_model_device,
dtype=(
torch.float16
if (self.eval_captioning_model_device == "cuda" and torch.cuda.is_available())
else torch.float32
),
model_name="Salesforce/blip2-flan-t5-xl",
)
self.evaluator = evaluator_router(self.config_file, captioning_fn=captioning_fn)

# reset instance if needed (classifieds domain only)
if self.config.get("require_reset", False):
Expand Down

0 comments on commit 1ed9a7a

Please sign in to comment.