From 097db8aad912710ea4c2e5a73da3ec2f4570d9ea Mon Sep 17 00:00:00 2001
From: Ian Barber
Date: Mon, 18 Dec 2023 23:39:27 -0800
Subject: [PATCH] Minor fixes for runnability

- Correct passing of configuration in train transformer and reward steps
- Fix typo on prompt logger
- Update log_scalar to match class spec.
---
 examples/rlhf/train.py        | 2 +-
 examples/rlhf/train_reward.py | 2 +-
 examples/rlhf/utils.py        | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py
index 2e554f3edb9..6d9e758503d 100644
--- a/examples/rlhf/train.py
+++ b/examples/rlhf/train.py
@@ -62,7 +62,7 @@ def main(cfg):
     dtype = cfg.sys.dtype
     compile_ = cfg.sys.compile
 
-    ctx = setup(device=device, dtype=dtype)
+    ctx = setup(cfg.sys)
 
     train_loader = get_dataloader(
         data_cfg.batch_size,
diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py
index e16fbf45474..75d3ad86adc 100644
--- a/examples/rlhf/train_reward.py
+++ b/examples/rlhf/train_reward.py
@@ -65,7 +65,7 @@ def main(cfg):
     dtype = cfg.sys.dtype
     compile_ = cfg.sys.compile
 
-    ctx = setup(device=device, dtype=dtype)
+    ctx = setup(cfg.sys)
 
     train_loader = get_dataloader(
         data_cfg.batch_size,
diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py
index 198b2e72bcb..d9a3ce42d48 100644
--- a/examples/rlhf/utils.py
+++ b/examples/rlhf/utils.py
@@ -130,7 +130,7 @@ def __init__(
     ):
         self.reward_estimator = reward_estimator
         self.model = model
-        self.promp_logger = prompt_logger
+        self.prompt_logger = prompt_logger
         self.io_cfg = io_cfg
         self.eval_interval = io_cfg.eval_interval
         self.log_interval = io_cfg.log_interval
@@ -154,7 +154,7 @@ def maybe_evaluate(self):
         val_reward = self.reward_estimator(self.model, self.val_loader)
         self.prompt_logger.log(self.model)
         self.val_reward_logger.info(f"VALID: {self.it=}: {val_reward=:.4f}")
-        self.logger.log_scalar({"val_reward": val_reward}, step=self.it)
+        self.logger.log_scalar("val_reward", val_reward, step=self.it)
         # pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}")
         if val_reward > self.best_val_reward:
             self.best_val_reward = val_reward
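
For reference, the log_scalar change in the last hunk lines up with torchrl's Logger interface, which takes the metric name and value as separate arguments rather than a dict. A minimal sketch of the corrected call, assuming a CSVLogger instance (the concrete logger class, experiment name and paths here are illustrative, not taken from the example config):

    # Sketch only: logger type, experiment name and log_dir are assumptions.
    from torchrl.record.loggers.csv import CSVLogger

    logger = CSVLogger(exp_name="rlhf_example", log_dir="./logs")

    val_reward = 0.1234  # placeholder metric value
    it = 10              # placeholder training iteration

    # Logger.log_scalar(name, value, step=...) takes name and value as
    # separate arguments, not a {"name": value} dict, hence the fix above.
    logger.log_scalar("val_reward", val_reward, step=it)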