Skip to content

Commit

Permalink
Logging debug (#47)
Browse files Browse the repository at this point in the history
* put reward to 0 as default

* add some logging

* Makefile update
  • Loading branch information
recursix authored May 27, 2024
1 parent 37d8dba commit bf9f4d3
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
install:
@echo "--- 🚀 Installing project dependencies ---"
pip install -e ./core ./miniwob ./webarena ./experiments ./
pip install -e ./core -e ./miniwob -e ./webarena -e ./experiments -e .

install-demo:
@echo "--- 🚀 Installing demo dependencies ---"
Expand Down
6 changes: 4 additions & 2 deletions core/src/browsergym/core/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

CHATBOX_DIR = resources.files(chat_files)

logger = logging.getLogger(__name__)


class Chat:
def __init__(
Expand Down Expand Up @@ -61,12 +63,12 @@ def add_message(self, role: Literal["user", "assistant", "info", "infeasible"],
self.page.evaluate(f"addChatMessage({repr(role)}, {repr(timestamp)}, {repr(msg)});")

def wait_for_user_message(self):
logging.info("Waiting for message from user...")
logger.info("Waiting for message from user...")
# reset flag
self.page.evaluate("USER_MESSAGE_RECEIVED = false;")
# wait for flag to be raised
self.page.wait_for_function("USER_MESSAGE_RECEIVED", polling=100, timeout=0)
logging.info("Message received.")
logger.info("Message received.")

def close(self):
self.context.close()
Expand Down
21 changes: 16 additions & 5 deletions core/src/browsergym/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from . import _get_global_playwright


logger = logging.getLogger(__name__)


class BrowserEnv(gym.Env, ABC):
"""The main BrowserGym class, which encapsulates instruction-following Web browsing into a Gymnasium environment."""

Expand Down Expand Up @@ -169,7 +172,7 @@ def override_property(task, env, property):
if env_value is None:
return task_value
else:
logging.warning(
logger.warning(
f"Overriding the task's {property} parameter ({repr(task_value)} => {repr(env_value)}). This might change the task's behaviour and difficulty."
)
return env_value
Expand Down Expand Up @@ -300,6 +303,7 @@ def override_property(task, env, property):
return obs, info

def step(self, action: str) -> tuple:

self.last_action = action

info = {}
Expand All @@ -314,6 +318,7 @@ def report_infeasible_instructions(reason: str):
self.infeasible_message_received = True

# try to execute the action
logger.debug(f"Executing action")
try:
if self.action_mapping:
code = self.action_mapping(action)
Expand All @@ -331,7 +336,7 @@ def report_infeasible_instructions(reason: str):
match = re.match("TimeoutError: Timeout ([0-9]+)ms exceeded.", self.last_action_error)
if match:
info["action_exec_timeout"] = float(match.groups()[0]) / 1000 # ms to sec

logger.debug(f"Action executed")
info["action_exec_stop"] = time.time()

# wait a bit (for the JavaScript callback to set the active page)
Expand All @@ -344,20 +349,25 @@ def report_infeasible_instructions(reason: str):
# after the action is executed, the active page might have changed
# perform a safety check
self._active_page_check()
logger.debug(f"Active page checked")

# if asked, wait for user message
self._wait_for_user_message()
logger.debug(f"User message done")

logger.debug(f"Initiating task validation")
# extract reward, done, user_message, info (task-specific)
reward, done, user_message, task_info = self._task_validate()
info["task_info"] = task_info
logger.debug(f"Task validation done")

# add any user message sent by the task to the chat
if user_message:
self.chat.add_message(role="user", msg=user_message)

# extract observation (generic)
obs = self._get_obs()
logger.debug(f"Observation extracted")

# new step API wants a 5-tuple (gymnasium)
terminated = done or (
Expand All @@ -377,7 +387,7 @@ def _task_validate(self):

# safety fix, in case validate() did mess up the active page and/or page history
if prev_active_page != self.page or prev_page_history != self.page_history:
logging.info(
logger.info(
"The active page and / or page history has changed during task.validate(). A recovery fix will be applied."
)
self.page = prev_active_page
Expand All @@ -404,6 +414,7 @@ def _wait_dom_loaded(self):
pass

def _activate_page_from_js(self, page: playwright.sync_api.Page):
logger.debug(f"_activate_page_from_js(page) called, page={str(page)}")
if not page.context == self.context:
raise RuntimeError(
f"Unexpected: activating a page that belongs to a different browser context ({page})."
Expand All @@ -423,7 +434,7 @@ def _active_page_check(self):
# make sure there is always a page open
# if all pages have been closed, create a new page
if len(self.context.pages) == 0:
logging.warning(f"All pages are closed, opening a new page.")
logger.warning(f"All pages are closed, opening a new page.")
self.page = self.context.new_page()

# if the active page got closed, get the last active page from the history
Expand Down Expand Up @@ -464,7 +475,7 @@ def _get_obs(self):
or "Frame has been detached" in err_msg
or "Cannot mark a child frame without a bid" in err_msg
):
logging.warning(
logger.warning(
f"An error occured while extracting the dom and axtree. Retrying ({retries_left}/{EXTRACT_OBS_MAX_TRIES} tries left).\n{repr(e)}"
)
# post-extract cleanup (aria-roledescription attribute)
Expand Down
15 changes: 9 additions & 6 deletions core/src/browsergym/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
MARK_FRAMES_MAX_TRIES = 3


logger = logging.getLogger(__name__)


class MarkingError(Exception):
pass

Expand All @@ -38,7 +41,7 @@ def mark_frames_recursive(frame, frame_bid: str):
)
# print warning messages if any
for msg in warning_msgs:
logging.warning(msg)
logger.warning(msg)

# recursively mark all descendant frames
for child_frame in frame.child_frames:
Expand All @@ -48,7 +51,7 @@ def mark_frames_recursive(frame, frame_bid: str):
# deal with weird frames (pdf viewer in <embed>)
child_frame_elem = child_frame.frame_element()
if not child_frame_elem.content_frame() == child_frame:
logging.warning(
logger.warning(
f"Skipping frame '{child_frame.name}' for marking, seems problematic."
)
continue
Expand Down Expand Up @@ -76,7 +79,7 @@ def _post_extract(page: playwright.sync_api.Page):
if not frame == page.main_frame:
# deal with weird frames (pdf viewer in <embed>)
if not frame.frame_element().content_frame() == frame:
logging.warning(f"Skipping frame '{frame.name}' for unmarking, seems problematic.")
logger.warning(f"Skipping frame '{frame.name}' for unmarking, seems problematic.")
continue
# deal with sandboxed frames with blocked script execution
sandbox_attr = frame.frame_element().get_attribute("sandbox")
Expand Down Expand Up @@ -142,7 +145,7 @@ def extract_data_items_from_aria(string):

match = __DATA_REGEXP.fullmatch(string)
if not match:
logging.warning(
logger.warning(
f'Data items could not be extracted from "aria-roledescription" attribute: {string}'
)
return [], string
Expand Down Expand Up @@ -379,7 +382,7 @@ def to_string(idx):
bid = node["bid"]
if bid:
if bid in extra_properties:
logging.warning(f"duplicate {BID_ATTR}={repr(bid)} attribute detected")
logger.warning(f"duplicate {BID_ATTR}={repr(bid)} attribute detected")
extra_properties[bid] = {
extra_prop: node[extra_prop]
for extra_prop in ("visibility", "bbox", "clickable", "set_of_marks")
Expand Down Expand Up @@ -490,7 +493,7 @@ def extract_merged_axtree(page: playwright.sync_api.Page):
assert frame_root_node["frameId"] == frame_id
node["childIds"].append(frame_root_node["nodeId"])
else:
logging.warning(f"Extracted AXTree does not contain frameId '{frame_id}'")
logger.warning(f"Extracted AXTree does not contain frameId '{frame_id}'")

cdp.detach()

Expand Down
3 changes: 0 additions & 3 deletions demo_agent/agents/legacy/utils/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
from transformers import GPT2TokenizerFast


logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


@dataclass
class ChatModelArgs:
"""Serializable object for instantiating a generic chat model.
Expand Down
66 changes: 59 additions & 7 deletions experiments/src/browsergym/experiments/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from .agent import Agent
from .utils import count_messages_token, count_tokens

logger = logging.getLogger(__name__)


@dataclass
class EnvArgs:
Expand Down Expand Up @@ -112,7 +114,8 @@ class ExpArgs:
enable_debug: bool = True
err_msg: str = None
stack_trace: str = None
order: int = None # use to keep the original order the experiments were meant to be lancuhed.
order: int = None # use to keep the original order the experiments were meant to be launched.
logging_level: int = logging.INFO

def prepare(self, exp_root):
"""Prepare the experiment directory and save the experiment arguments.
Expand Down Expand Up @@ -148,30 +151,43 @@ def prepare(self, exp_root):
def run(self):
"""Run the experiment and save the results"""

# start writing logs to run logfile
self._set_logger()

episode_info = []
try:
logging.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
agent = self.agent_args.make_agent()
logger.debug(f"Agent created.")
env = self.env_args.make_env(
action_mapping=agent.action_set.to_python_code, exp_dir=self.exp_dir
)
logger.debug(f"Environment created.")

err_msg, stack_trace = None, None
step_info = StepInfo(step=0)
episode_info = [step_info]
step_info.from_reset(env, seed=self.env_args.task_seed)
logger.debug(f"Environment reset.")

while not step_info.is_done: # set a limit
logger.debug(f"Starting step {step_info.step}.")
action = step_info.from_action(agent)
logger.debug(f"Agent chose action:\n {action}")

step_info.save_step_info(self.exp_dir)
logger.debug(f"Step info saved.")
if action is None:
break

_send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
logger.debug(f"Chat info sent.")

step_info = StepInfo(step=step_info.step + 1)
episode_info.append(step_info)
logger.debug(f"Sending action to environment.")
step_info.from_step(env, action)
logger.debug(f"Environment stepped.")

except Exception as e:
err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
Expand All @@ -180,17 +196,45 @@ def run(self):
self.err_msg = err_msg
self.stack_trace = stack_trace

logging.warning(err_msg + "\n" + stack_trace)
logger.warning(err_msg + "\n" + stack_trace)
if _is_debugging() and self.enable_debug:
raise

finally:
try:
step_info.save_step_info(self.exp_dir)
except Exception as e:
logger.error(f"Error while saving step info in the finally block: {e}")
try:
_save_summary_info(episode_info, self.exp_dir, err_msg, stack_trace)
except Exception as e:
logger.error(f"Error while saving summary info in the finally block: {e}")
try:
env.close()
except Exception as e:
logging.error(f"Error while finalizing the experiment loop: {e}")
logger.error(f"Error while closing the environment in the finally block: {e}")
# stop writing logs to run logfile
self._unset_logger()

def _set_logger(self):
# output logging traces to a log file
file_handler = logging.FileHandler(self.exp_dir / "experiment.log")
file_handler.setLevel(self.logging_level) # same level as console outputs
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
# setup root logger
root_logger = logging.getLogger()
root_logger.setLevel(self.logging_level)
root_logger.addHandler(file_handler)
# setup openai logger (don't go below INFO verbosity)
openai_logger = logging.getLogger("openai._base_client")
openai_logger.setLevel(max(logging.INFO, self.logging_level))

self.logging_file_handler = file_handler

def _unset_logger(self):
root_logger = logging.getLogger()
root_logger.removeHandler(self.logging_file_handler)


@dataclass
Expand Down Expand Up @@ -236,8 +280,8 @@ class StepInfo:

step: int = None
obs: dict = None
reward: float = None
raw_reward: float = None
reward: float = 0
raw_reward: float = 0
terminated: bool = None
truncated: bool = None
action: str = None
Expand Down Expand Up @@ -428,6 +472,7 @@ def __init__(self, exp_dir) -> None:
self._summary_info = None
self._screenshots = {}
self._flat_exp_args = None
self._logs = None

@property
def exp_args(self):
Expand Down Expand Up @@ -518,6 +563,13 @@ def task_video_path(self) -> Path:
def combined_video_path(self) -> Path:
return self.exp_dir / "combined_video.mp4"

@property
def logs(self):
if self._logs is None:
with open(self.exp_dir / "experiment.log", "r") as f:
self._logs = f.read()
return self._logs


EXP_RESULT_CACHE = {}

Expand Down Expand Up @@ -608,7 +660,7 @@ def _send_chat_info(chat: Chat, action: str, agent_info: dict):
{action}
"""

logging.info(msg)
logger.info(msg)
chat.add_message(role="info", msg=msg)


Expand Down
4 changes: 3 additions & 1 deletion webarena/src/browsergym/webarena/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from .instance import WebArenaInstance

logger = logging.getLogger(__name__)


class GenericWebArenaTask(AbstractBrowserTask):
"""
Expand Down Expand Up @@ -176,7 +178,7 @@ def validate(
)
# llm_fuzzy_match() bugfix (assert "correct" in response)
except AssertionError as e:
logging.info(
logger.info(
"llm_fuzzy_match() bugfix applied: AssertionError in evaluator, using score = 0.0"
)
score = 0.0
Expand Down

0 comments on commit bf9f4d3

Please sign in to comment.