diff --git a/demo_agent/agents/basic/agent.py b/demo_agent/agents/basic/agent.py index 36d896a8..291a0840 100644 --- a/demo_agent/agents/basic/agent.py +++ b/demo_agent/agents/basic/agent.py @@ -3,20 +3,26 @@ from browsergym.experiments import Agent, AbstractAgentArgs from browsergym.core.action.highlevel import HighLevelActionSet from browsergym.core.action.python import PythonActionSet +from browsergym.utils.obs import flatten_axtree_to_str class DemoAgent(Agent): """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities.""" - @property - def action_mapping(self): - return self.action_space.to_python_code + def action_mapping(self, action: str): + return self.action_space.to_python_code(action) + + def observation_mapping(self, obs: dict) -> dict: + return { + "goal": obs["goal"], + "axtree_txt": flatten_axtree_to_str(obs["axtree_object"]), + } def __init__(self, model_name) -> None: super().__init__() self.model_name = model_name - action_space = HighLevelActionSet( + self.action_space = HighLevelActionSet( subsets=["bid"], # define a subset of the action space # subsets=["bid", "coord"] # allows the agent to also use x,y coordinates strict=False, # less strict on the parsing of the actions @@ -25,7 +31,7 @@ def __init__(self, model_name) -> None: ) # uncomment this line to allow the agent to also use Python full python code - # action_space = PythonActionSet() + # self.action_space = PythonActionSet() from openai import OpenAI diff --git a/demo_agent/agents/legacy/agent.py b/demo_agent/agents/legacy/agent.py index bc76ffb1..95e9e420 100644 --- a/demo_agent/agents/legacy/agent.py +++ b/demo_agent/agents/legacy/agent.py @@ -30,38 +30,34 @@ def make_agent(self): class GenericAgent(Agent): - @property - def observation_mapping(self): + def observation_mapping(self, obs: dict) -> dict: """ Augment observations with text HTML and AXTree representations, which will be stored in the experiment traces. 
""" - def augmented_obs(obs): - obs = obs.copy() - obs["dom_txt"] = flatten_dom_to_str( - obs["dom_object"], - with_visible=self.flags.extract_visible_tag, - with_center_coords=self.flags.extract_coords == "center", - with_bounding_box_coords=self.flags.extract_coords == "box", - filter_visible_only=self.flags.extract_visible_elements_only, - ) - obs["axtree_txt"] = flatten_axtree_to_str( - obs["axtree_object"], - with_visible=self.flags.extract_visible_tag, - with_center_coords=self.flags.extract_coords == "center", - with_bounding_box_coords=self.flags.extract_coords == "box", - filter_visible_only=self.flags.extract_visible_elements_only, - ) - obs["pruned_html"] = prune_html(obs["dom_txt"]) - return obs - - return augmented_obs - - @property - def action_mapping(self): + obs = obs.copy() + obs["dom_txt"] = flatten_dom_to_str( + obs["dom_object"], + with_visible=self.flags.extract_visible_tag, + with_center_coords=self.flags.extract_coords == "center", + with_bounding_box_coords=self.flags.extract_coords == "box", + filter_visible_only=self.flags.extract_visible_elements_only, + ) + obs["axtree_txt"] = flatten_axtree_to_str( + obs["axtree_object"], + with_visible=self.flags.extract_visible_tag, + with_center_coords=self.flags.extract_coords == "center", + with_bounding_box_coords=self.flags.extract_coords == "box", + filter_visible_only=self.flags.extract_visible_elements_only, + ) + obs["pruned_html"] = prune_html(obs["dom_txt"]) + + return obs + + def action_mapping(self, action: str) -> str: """Use a BrowserGym AbstractActionSet mapping.""" - return self.action_space.to_python_code + return self.action_space.to_python_code(action) def __init__( self, diff --git a/experiments/src/browsergym/experiments/agent.py b/experiments/src/browsergym/experiments/agent.py index e0c47327..9f3bec27 100644 --- a/experiments/src/browsergym/experiments/agent.py +++ b/experiments/src/browsergym/experiments/agent.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from 
typing import Optional +from typing import Any class Agent(ABC): @@ -8,51 +8,47 @@ class Agent(ABC): with a browsergym environment. """ - @property - @abstractmethod - def action_mapping(self) -> Optional[callable]: + def action_mapping(self, action: str) -> str: """ - Returns a function that maps the actions returned by `get_action()` to BrowserGym-compatible Python code. + This method is meant to be overloaded by your agent (optional). + + Maps the actions returned by `get_action()` to BrowserGym-compatible Python code. Why this mapping? This mapping will happen within the BrowserGym environment, so that the experiment loop manipulates and records pre-mapping actions and not the resulting Python code (which can be pretty verbose). - This property is meant to be overloaded by your agent. - Examples: # no mapping, the agent directly produces Python code - return None + return action # use a pre-defined action set of high-level function - action_set = browsergym.core.action.highlevel.HighLevelActionSet(subsets=["chat", "nav", "bid"]) - return action_set.to_python_code + action_space = browsergym.core.action.highlevel.HighLevelActionSet(subsets=["chat", "nav", "bid"]) + return action_space.to_python_code(action) # use a pre-defined Python action set which extracts Markdown code snippets - action_set = browsergym.core.action.python.PythonActionSet() - return action_set.to_python_code + action_space = browsergym.core.action.python.PythonActionSet() + return action_space.to_python_code(action) """ - return None + return action - @property - @abstractmethod - def observation_mapping(self) -> Optional[callable]: + def observation_mapping(self, obs: dict) -> Any: """ + This method is meant to be overloaded by your agent. + Returns a function that pre-processes the observations before feeding them to `get_action()`. Why this mapping? This mapping will happen within the experiment loop, so that the resulting observation gets recorded in the execution traces. 
- This property is meant to be overloaded by your agent. - Examples: from browsergym.utils.obs import flatten_axtree_to_str - return lambda obs: { - "axtree": flatten_axtree_to_str(obs["axtree_object"]), + return { "goal": obs["goal"], + "axtree": flatten_axtree_to_str(obs["axtree_object"]), } """ - return None + return obs @abstractmethod - def get_action(self, obs) -> tuple[str, dict]: + def get_action(self, obs: Any) -> tuple[str, dict]: """ Updates the agent with the current observation, and returns its next action (plus an info dict, optional). diff --git a/experiments/src/browsergym/experiments/loop.py b/experiments/src/browsergym/experiments/loop.py index 8e04927b..22846919 100644 --- a/experiments/src/browsergym/experiments/loop.py +++ b/experiments/src/browsergym/experiments/loop.py @@ -257,7 +257,8 @@ def from_step(self, env: gym.Env, action: str): def from_action(self, agent: Agent): self.profiling.agent_start = time.time() - self.obs = agent.observation_mapping(self.obs) + if agent.observation_mapping: + self.obs = agent.observation_mapping(self.obs) self.action, self.agent_info = agent.get_action(self.obs) self.profiling.agent_stop = time.time() diff --git a/experiments/tests/test_exp_loop.py b/experiments/tests/test_exp_loop.py index b0c37423..1bd4e8d0 100644 --- a/experiments/tests/test_exp_loop.py +++ b/experiments/tests/test_exp_loop.py @@ -6,10 +6,17 @@ from browsergym.core.action.highlevel import HighLevelActionSet from browsergym.experiments.agent import Agent from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result +from browsergym.utils.obs import flatten_axtree_to_str class MiniwobTestAgent(Agent): + def observation_mapping(self, obs: dict): + return {"axtree_txt": flatten_axtree_to_str(obs["axtree_object"])} + + def action_mapping(self, action: str): + return self.action_space.to_python_code(action) + def __init__(self): self.action_space = HighLevelActionSet(subsets="bid")