Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactors #255

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions browsergym/core/src/browsergym/core/registration.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
from functools import partial
from typing import Type

import gymnasium as gym

from .env import BrowserEnv
from .task import AbstractBrowserTask


class frozen_partial:
    """
    Callable wrapper that freezes some keyword arguments of a function.

    The frozen keyword arguments are passed to `func` on every call. Supplying
    any of them again at call time is rejected with a `ValueError` instead of
    silently overriding the frozen value.

    Args:
        func: the callable to wrap.
        **frozen_kwargs: keyword arguments that are always forwarded to `func`.

    """

    def __init__(self, func, **frozen_kwargs):
        self.func = func
        self.frozen_kwargs = frozen_kwargs

    def __repr__(self):
        return f"{type(self).__name__}({self.func!r}, **{self.frozen_kwargs!r})"

    def __call__(self, *args, **kwargs):
        """
        Call the wrapped function with the given arguments plus the frozen ones.

        Args:
            *args: positional arguments forwarded to the wrapped function.
            **kwargs: keyword arguments forwarded to the wrapped function; must
                not contain any frozen keyword.

        Returns:
            Whatever the wrapped function returns.

        Raises:
            ValueError: if a keyword in `kwargs` is also a frozen keyword.
        """
        # keys present in both dicts; sorted so that the error message is
        # stable from one run to the next (set ordering of strings is not)
        clashing_kwargs = sorted(set(self.frozen_kwargs) & set(kwargs))
        if clashing_kwargs:
            clashing_repr = "{" + ", ".join(map(repr, clashing_kwargs)) + "}"
            raise ValueError(f"Illegal attempt to override frozen parameters {clashing_repr}.")

        # the two dicts are disjoint at this point, so unpacking both is safe
        return self.func(*args, **kwargs, **self.frozen_kwargs)


def register_task(
id: str,
task_class: Type[AbstractBrowserTask],
task_kwargs: dict = None,
task_kwargs: dict = {},
default_task_kwargs: dict = {},
nondeterministic: bool = True,
*args,
**kwargs,
Expand All @@ -19,20 +43,32 @@ def register_task(
Args:
id: the id of the task to register (will be prepended by "browsergym/").
task_class: the task class to register.
task_kwargs: frozen task arguments (can not be overloaded at environment creation time).
default_task_kwargs: default task arguments (can be overridden at environment creation time).
nondeterministic: whether the task cannot be guaranteed deterministic transitions.
*args: additional arguments for the browsergym environment.
*kwargs: additional arguments for the browsergym environment.
*args: additional sequential arguments for either the gym or the browsergym environment.
**kwargs: additional keyword arguments for either the gym or the browsergym environment.
"""
if task_kwargs and default_task_kwargs:
# check overlap between frozen and default task_kwargs
clashing_kwargs = set(task_kwargs) & set(default_task_kwargs) # key set intersection
if clashing_kwargs:
raise ValueError(
f"Illegal attempt to register Browsergym environment {id} with both frozen and default values for task parameters {clashing_kwargs}."
)

task_entrypoint = task_class

# freeze task_kwargs (cannot be overridden at environment creation)
task_entrypoint = frozen_partial(task_class, **task_kwargs)

# these environment arguments will be fixed, and error will be raised if they are set when calling gym.make()
fixed_env_kwargs = {}
if task_kwargs is not None:
fixed_env_kwargs["task_kwargs"] = task_kwargs
# pre-set default_task_kwargs (can be overridden at environment creation)
task_entrypoint = partial(task_entrypoint, **default_task_kwargs)

gym.register(
id=f"browsergym/{id}",
entry_point=lambda *env_args, **env_kwargs: BrowserEnv(
task_class, *env_args, **fixed_env_kwargs, **env_kwargs
task_entrypoint, *env_args, **env_kwargs
),
nondeterministic=nondeterministic,
*args,
Expand Down
14 changes: 7 additions & 7 deletions browsergym/core/src/browsergym/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
info: dict, custom information from the task.
"""

@abstractmethod
def teardown(self) -> None:
"""
Tear down the task and clean up any ressource / data created by the task.

"""

@abstractmethod
def validate(
self, page: playwright.sync_api.Page, chat_messages: list[str]
Expand All @@ -74,6 +67,13 @@ def cheat(self, page: playwright.sync_api.Page, chat_messages: list[str]) -> Non
"""
raise NotImplementedError

def teardown(self) -> None:
"""
Tear down the task and clean up any ressource / data created by the task (optional).

"""
pass


class OpenEndedTask(AbstractBrowserTask):
@classmethod
Expand Down
98 changes: 98 additions & 0 deletions tests/core/test_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import re

import gymnasium as gym
import pytest

from browsergym.core.registration import register_task
from browsergym.core.task import AbstractBrowserTask


class TestTask(AbstractBrowserTask):
    """
    Minimal task used to exercise task registration.

    It records its three constructor parameters (`a`, `b`, `c`) as attributes
    so that tests can check which values reached the task, and returns constant
    values from `setup` and `validate`.
    """

    # The class name starts with "Test" and the class defines __init__, so
    # pytest would try to collect it and emit a collection warning. Opt out.
    __test__ = False

    @classmethod
    def get_task_id(cls):
        raise NotImplementedError

    def __init__(self, a: str = "", b: int = 0, c: bool = False, *args, **kwargs):
        # remaining arguments are forwarded untouched to the base task
        super().__init__(*args, **kwargs)
        self.a = a
        self.b = b
        self.c = c

    def setup(self, page):
        return "", {}

    def teardown(self):
        pass

    def validate(self, page, chat_messages):
        return 0, True, "", {}


# Plain registration: neither frozen nor default task arguments.
register_task(id="test_task", task_class=TestTask)

# Registration with a frozen value for "a" and a default value for "b".
register_task(
    id="test_task_with_defaults",
    task_class=TestTask,
    task_kwargs={"a": "new value"},
    default_task_kwargs={"b": 1},
)


def test_registration():
    """
    Check frozen and default task arguments through `register_task` / `gym.make`.

    Covers: rejection of a registration where the same parameter is both frozen
    and defaulted, the task attributes obtained with and without `task_kwargs`
    at environment creation, and the error raised on reset when a frozen
    parameter is overridden.
    """

    # the same parameter cannot be both frozen and given a default value
    with pytest.raises(ValueError):
        register_task(
            "test_task_forbidden",
            TestTask,
            task_kwargs={"a": "new value"},
            default_task_kwargs={"a": "other value"},
        )

    # plain task, no arguments at creation time
    env = gym.make("browsergym/test_task")

    assert env.unwrapped.task_kwargs == {}

    env.reset()
    assert env.unwrapped.task.a == ""
    assert env.unwrapped.task.b == 0
    assert env.unwrapped.task.c is False
    env.close()

    # plain task, arguments given at creation time
    env = gym.make("browsergym/test_task", task_kwargs={"a": "other", "b": 1})

    assert env.unwrapped.task_kwargs == {"a": "other", "b": 1}

    env.reset()
    assert env.unwrapped.task.a == "other"
    assert env.unwrapped.task.b == 1
    assert env.unwrapped.task.c is False
    env.close()

    # task with frozen / default arguments, nothing given at creation time
    env = gym.make("browsergym/test_task_with_defaults")

    assert env.unwrapped.task_kwargs == {}

    env.reset()
    assert env.unwrapped.task.a == "new value"
    assert env.unwrapped.task.b == 1
    assert env.unwrapped.task.c is False
    env.close()

    # a default argument can be overridden at creation time
    env = gym.make("browsergym/test_task_with_defaults", task_kwargs={"b": 2})

    assert env.unwrapped.task_kwargs == {"b": 2}

    env.reset()
    assert env.unwrapped.task.a == "new value"
    assert env.unwrapped.task.b == 2
    assert env.unwrapped.task.c is False
    env.close()

    # a frozen argument cannot be overridden: the error surfaces on reset
    env = gym.make("browsergym/test_task_with_defaults", task_kwargs={"a": "other"})

    assert env.unwrapped.task_kwargs == {"a": "other"}

    with pytest.raises(
        expected_exception=ValueError,
        match=re.compile("Illegal attempt to override frozen parameters"),
    ):
        env.reset()