Skip to content

Commit

Permalink
agent simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed May 17, 2024
1 parent 001a74e commit 1793a8b
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 54 deletions.
16 changes: 11 additions & 5 deletions demo_agent/agents/basic/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,26 @@
from browsergym.experiments import Agent, AbstractAgentArgs
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet
from browsergym.utils.obs import flatten_axtree_to_str


class DemoAgent(Agent):
"""A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities."""

@property
def action_mapping(self):
return self.action_space.to_python_code
def action_mapping(self, action: str):
return self.action_space.to_python_code(action)

def observation_mapping(self, obs: dict) -> dict:
return {
"goal": obs["goal"],
"axtree_txt": flatten_axtree_to_str(obs["axtree_object"]),
}

def __init__(self, model_name) -> None:
super().__init__()
self.model_name = model_name

action_space = HighLevelActionSet(
self.action_space = HighLevelActionSet(
subsets=["bid"], # define a subset of the action space
# subsets=["bid", "coord"] # allows the agent to also use x,y coordinates
strict=False, # less strict on the parsing of the actions
Expand All @@ -25,7 +31,7 @@ def __init__(self, model_name) -> None:
)

# uncomment this line to allow the agent to also use full Python code
# action_space = PythonActionSet()
# self.action_space = PythonActionSet()

from openai import OpenAI

Expand Down
48 changes: 22 additions & 26 deletions demo_agent/agents/legacy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,34 @@ def make_agent(self):

class GenericAgent(Agent):

@property
def observation_mapping(self):
def observation_mapping(self, obs: dict) -> dict:
"""
Augment observations with text HTML and AXTree representations, which will be stored in
the experiment traces.
"""

def augmented_obs(obs):
obs = obs.copy()
obs["dom_txt"] = flatten_dom_to_str(
obs["dom_object"],
with_visible=self.flags.extract_visible_tag,
with_center_coords=self.flags.extract_coords == "center",
with_bounding_box_coords=self.flags.extract_coords == "box",
filter_visible_only=self.flags.extract_visible_elements_only,
)
obs["axtree_txt"] = flatten_axtree_to_str(
obs["axtree_object"],
with_visible=self.flags.extract_visible_tag,
with_center_coords=self.flags.extract_coords == "center",
with_bounding_box_coords=self.flags.extract_coords == "box",
filter_visible_only=self.flags.extract_visible_elements_only,
)
obs["pruned_html"] = prune_html(obs["dom_txt"])
return obs

return augmented_obs

@property
def action_mapping(self):
obs = obs.copy()
obs["dom_txt"] = flatten_dom_to_str(
obs["dom_object"],
with_visible=self.flags.extract_visible_tag,
with_center_coords=self.flags.extract_coords == "center",
with_bounding_box_coords=self.flags.extract_coords == "box",
filter_visible_only=self.flags.extract_visible_elements_only,
)
obs["axtree_txt"] = flatten_axtree_to_str(
obs["axtree_object"],
with_visible=self.flags.extract_visible_tag,
with_center_coords=self.flags.extract_coords == "center",
with_bounding_box_coords=self.flags.extract_coords == "box",
filter_visible_only=self.flags.extract_visible_elements_only,
)
obs["pruned_html"] = prune_html(obs["dom_txt"])

return obs

def action_mapping(self, action: str) -> str:
"""Use a BrowserGym AbstractActionSet mapping."""
return self.action_space.to_python_code
return self.action_space.to_python_code(action)

def __init__(
self,
Expand Down
40 changes: 18 additions & 22 deletions experiments/src/browsergym/experiments/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any


class Agent(ABC):
Expand All @@ -8,51 +8,47 @@ class Agent(ABC):
with a browsergym environment.
"""

@property
@abstractmethod
def action_mapping(self) -> Optional[callable]:
def action_mapping(self, action: str) -> str:
"""
Returns a function that maps the actions returned by `get_action()` to BrowserGym-compatible Python code.
This property is meant to be overloaded by your agent (optional).
Maps the actions returned by `get_action()` to BrowserGym-compatible Python code.
Why this mapping? This mapping will happen within the BrowserGym environment, so that the experiment loop
manipulates and records pre-mapping actions and not the resulting Python code (which can be pretty verbose).
This method is meant to be overridden by your agent.
Examples:
# no mapping, the agent directly produces Python code
return None
return action
# use a pre-defined action set of high-level functions
action_set = browsergym.core.action.highlevel.HighLevelActionSet(subsets=["chat", "nav", "bid"])
return action_set.to_python_code
action_space = browsergym.core.action.highlevel.HighLevelActionSet(subsets=["chat", "nav", "bid"])
return action_space.to_python_code(action)
# use a pre-defined Python action set which extracts Markdown code snippets
action_set = browsergym.core.action.python.PythonActionSet()
return action_set.to_python_code
action_space = browsergym.core.action.python.PythonActionSet()
return action_space.to_python_code(action)
"""
return None
return action

@property
@abstractmethod
def observation_mapping(self) -> Optional[callable]:
def observation_mapping(self, obs: dict) -> Any:
"""
This method is meant to be overloaded by your agent.
Pre-processes the observations before feeding them to `get_action()`.
Why this mapping? This mapping will happen within the experiment loop, so that the resulting observation gets
recorded in the execution traces.
This property is meant to be overloaded by your agent.
Examples:
from browsergym.utils.obs import flatten_axtree_to_str
return lambda obs: {
"axtree": flatten_axtree_to_str(obs["axtree_object"]),
return {
"goal": obs["goal"],
"axtree": flatten_axtree_to_str(obs["axtree_object"]),
}
"""
return None
return obs

@abstractmethod
def get_action(self, obs) -> tuple[str, dict]:
def get_action(self, obs: Any) -> tuple[str, dict]:
"""
Updates the agent with the current observation, and returns its next action (plus an info dict, optional).
Expand Down
3 changes: 2 additions & 1 deletion experiments/src/browsergym/experiments/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def from_step(self, env: gym.Env, action: str):

def from_action(self, agent: Agent):
self.profiling.agent_start = time.time()
self.obs = agent.observation_mapping(self.obs)
if agent.observation_mapping:
self.obs = agent.observation_mapping(self.obs)
self.action, self.agent_info = agent.get_action(self.obs)
self.profiling.agent_stop = time.time()

Expand Down
7 changes: 7 additions & 0 deletions experiments/tests/test_exp_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import Agent
from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
from browsergym.utils.obs import flatten_axtree_to_str


class MiniwobTestAgent(Agent):

def observation_mapping(self, obs: dict):
return {"axtree_txt": flatten_axtree_to_str(obs["axtree_object"])}

def action_mapping(self, action: str):
return self.action_space.to_python_code(action)

def __init__(self):
self.action_space = HighLevelActionSet(subsets="bid")

Expand Down

0 comments on commit 1793a8b

Please sign in to comment.