From a623fe0d464043859bd45501a5812cdae2059c80 Mon Sep 17 00:00:00 2001 From: Maxime Gasse Date: Thu, 24 Oct 2024 10:50:01 -0400 Subject: [PATCH] save_step_info bugfix (obs=None) --- .../src/browsergym/experiments/loop.py | 56 ++++++++++--------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index 3d06fc6c..b89f0f41 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -438,40 +438,44 @@ def make_stats(self): def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_som=False): - screenshot = self.obs.pop("screenshot", None) - screenshot_som = self.obs.pop("screenshot_som", None) - - if save_screenshot and screenshot is not None: - img = Image.fromarray(screenshot) - img.save(exp_dir / f"screenshot_step_{self.step}.png") - - if save_som and screenshot_som is not None: - img = Image.fromarray(screenshot_som) - img.save(exp_dir / f"screenshot_som_step_{self.step}.png") - - # save goal object (which might contain images) to a separate file to save space - if self.obs is not None and self.obs.get("goal_object", False): - # save the goal object only once (goal should never change once setup) - goal_object_file = Path(exp_dir) / "goal_object.pkl.gz" - if not goal_object_file.exists(): - with gzip.open(goal_object_file, "wb") as f: - pickle.dump(self.obs["goal_object"], f) - # set goal_object to a special placeholder value, which indicates it should be loaded from a separate file - self.obs["goal_object"] = None + # special treatment for some of the observation fields + if self.obs is not None: + # save screenshots to separate files + screenshot = self.obs.pop("screenshot", None) + screenshot_som = self.obs.pop("screenshot_som", None) + + if save_screenshot and screenshot is not None: + img = Image.fromarray(screenshot) + img.save(exp_dir / f"screenshot_step_{self.step}.png") + + if save_som and screenshot_som is not None: + img = Image.fromarray(screenshot_som) + img.save(exp_dir / f"screenshot_som_step_{self.step}.png") + + # save goal object (which might contain images) to a separate file to save space + if self.obs.get("goal_object", False): + # save the goal object only once (goal should never change once setup) + goal_object_file = Path(exp_dir) / "goal_object.pkl.gz" + if not goal_object_file.exists(): + with gzip.open(goal_object_file, "wb") as f: + pickle.dump(self.obs["goal_object"], f) + # set goal_object to a special placeholder value, which indicates it should be loaded from a separate file + self.obs["goal_object"] = None with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f: - # TODO should we pop the screenshots too before this to save space ? pickle.dump(self, f) if save_json: with open(exp_dir / "steps_info.json", "w") as f: json.dump(self, f, indent=4, cls=DataclassJSONEncoder) - # add the screenshots back to the obs - if screenshot is not None: - self.obs["screenshot"] = screenshot - if screenshot_som is not None: - self.obs["screenshot_som"] = screenshot_som + if self.obs is not None: + # add the screenshots back to the obs + # why do we need this? + if screenshot is not None: + self.obs["screenshot"] = screenshot + if screenshot_som is not None: + self.obs["screenshot_som"] = screenshot_som def _extract_err_msg(episode_info: list[StepInfo]):