Skip to content

Commit

Permalink
Minor refactors (#255)
Browse files Browse the repository at this point in the history
* Optional method AbstractBrowserTask.teardown()

* browsergym registration refactor
  • Loading branch information
gasse authored Nov 14, 2024
1 parent 00429a7 commit 49558be
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 15 deletions.
52 changes: 44 additions & 8 deletions browsergym/core/src/browsergym/core/registration.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
from functools import partial
from typing import Type

import gymnasium as gym

from .env import BrowserEnv
from .task import AbstractBrowserTask


class frozen_partial:
"""
Freeze some keyword arguments of a function.
"""

def __init__(self, func, **frozen_kwargs):
self.func = func
self.frozen_kwargs = frozen_kwargs

def __call__(self, *args, **kwargs):
# check overlap between kwargs and frozen_kwargs
clashing_kwargs = set(self.frozen_kwargs) & set(kwargs) # key set intersection
if clashing_kwargs:
raise ValueError(f"Illegal attempt to override frozen parameters {clashing_kwargs}.")
# merge the two dicts
kwargs = kwargs | self.frozen_kwargs

return self.func(*args, **kwargs)


def register_task(
id: str,
task_class: Type[AbstractBrowserTask],
task_kwargs: dict = None,
task_kwargs: dict = {},
default_task_kwargs: dict = {},
nondeterministic: bool = True,
*args,
**kwargs,
Expand All @@ -19,20 +43,32 @@ def register_task(
Args:
id: the id of the task to register (will be prepended by "browsergym/").
task_class: the task class to register.
task_kwargs: frozen task arguments (can not be overloaded at environment creation time).
task_kwargs_default: default task arguments (can be overloaded at environment creation time).
nondeterministic: whether the task cannot be guaranteed deterministic transitions.
*args: additional arguments for the browsergym environment.
*kwargs: additional arguments for the browsergym environment.
*args: additional sequential arguments for either the gym or the browsergym environment.
*kwargs: additional keyword arguments for either the gym or the browsergym environment.
"""
if task_kwargs and default_task_kwargs:
# check overlap between frozen and default task_kwargs
clashing_kwargs = set(task_kwargs) & set(default_task_kwargs) # key set intersection
if clashing_kwargs:
raise ValueError(
f"Illegal attempt to register Browsergym environment {id} with both frozen and default values for task parameters {clashing_kwargs}."
)

task_entrypoint = task_class

# freeze task_kwargs (cannot be overriden at environment creation)
task_entrypoint = frozen_partial(task_class, **task_kwargs)

# these environment arguments will be fixed, and error will be raised if they are set when calling gym.make()
fixed_env_kwargs = {}
if task_kwargs is not None:
fixed_env_kwargs["task_kwargs"] = task_kwargs
# pre-set default_task_kwargs (can be overriden at environment creation)
task_entrypoint = partial(task_entrypoint, **default_task_kwargs)

gym.register(
id=f"browsergym/{id}",
entry_point=lambda *env_args, **env_kwargs: BrowserEnv(
task_class, *env_args, **fixed_env_kwargs, **env_kwargs
task_entrypoint, *env_args, **env_kwargs
),
nondeterministic=nondeterministic,
*args,
Expand Down
14 changes: 7 additions & 7 deletions browsergym/core/src/browsergym/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
info: dict, custom information from the task.
"""

@abstractmethod
def teardown(self) -> None:
"""
Tear down the task and clean up any ressource / data created by the task.
"""

@abstractmethod
def validate(
self, page: playwright.sync_api.Page, chat_messages: list[str]
Expand All @@ -74,6 +67,13 @@ def cheat(self, page: playwright.sync_api.Page, chat_messages: list[str]) -> Non
"""
raise NotImplementedError

def teardown(self) -> None:
"""
Tear down the task and clean up any ressource / data created by the task (optional).
"""
pass


class OpenEndedTask(AbstractBrowserTask):
@classmethod
Expand Down
98 changes: 98 additions & 0 deletions tests/core/test_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import re

import gymnasium as gym
import pytest

from browsergym.core.registration import register_task
from browsergym.core.task import AbstractBrowserTask


class TestTask(AbstractBrowserTask):
@classmethod
def get_task_id(cls):
raise NotImplementedError

def __init__(self, a: str = "", b: int = 0, c: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.a = a
self.b = b
self.c = c

def setup(self, page):
return "", {}

def teardown(self):
pass

def validate(self, page, chat_messages):
return 0, True, "", {}


register_task("test_task", TestTask)
register_task(
"test_task_with_defaults",
TestTask,
task_kwargs={"a": "new value"},
default_task_kwargs={"b": 1},
)


def test_registration():

with pytest.raises(ValueError):
register_task(
"test_task_forbidden",
TestTask,
task_kwargs={"a": "new value"},
default_task_kwargs={"a": "other value"},
)

env = gym.make("browsergym/test_task")

assert env.unwrapped.task_kwargs == {}

env.reset()
env.unwrapped.task.a == ""
env.unwrapped.task.b == 0
env.unwrapped.task.c == False
env.close()

env = gym.make("browsergym/test_task", task_kwargs={"a": "other", "b": 1})

assert env.unwrapped.task_kwargs == {"a": "other", "b": 1}

env.reset()
env.unwrapped.task.a == "other"
env.unwrapped.task.b == 1
env.unwrapped.task.c == False
env.close()

env = gym.make("browsergym/test_task_with_defaults")

assert env.unwrapped.task_kwargs == {}

env.reset()
env.unwrapped.task.a == "new value"
env.unwrapped.task.b == 1
env.unwrapped.task.c == False
env.close()

env = gym.make("browsergym/test_task_with_defaults", task_kwargs={"b": 2})

assert env.unwrapped.task_kwargs == {"b": 2}

env.reset()
env.unwrapped.task.a == "new value"
env.unwrapped.task.b == 2
env.unwrapped.task.c == False
env.close()

env = gym.make("browsergym/test_task_with_defaults", task_kwargs={"a": "other"})

assert env.unwrapped.task_kwargs == {"a": "other"}

with pytest.raises(
expected_exception=ValueError,
match=re.compile("Illegal attempt to override frozen parameters"),
):
env.reset()

0 comments on commit 49558be

Please sign in to comment.