From 0d528d0c3e73ca74ffa7c544dce1fda124b4d706 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Thu, 4 Jan 2024 16:25:47 +0000
Subject: [PATCH 1/2] Fix batch all gather

---
 trl/trainer/ppo_trainer.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 3a7ceb434b..01e8ed14d5 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -201,6 +201,7 @@ def __init__(
             project_config=ProjectConfiguration(**config.project_kwargs),
             **config.accelerator_kwargs,
         )
+        # torch.cuda.set_device(self.accelerator.local_process_index)
 
         # Step 1.1 Runtime variables filled by the accelerator
         config.world_size = self.accelerator.num_processes
@@ -1320,10 +1321,26 @@ def log_stats(
             rewards (`List[torch.FloatTensor]`):
                 A tensor of rewards.
         """
+
+        # all gather stats
         if not isinstance(rewards, torch.Tensor):
             rewards = torch.tensor(rewards).to(self.current_device)
         rewards = self.accelerator.gather(rewards).flatten()
 
+        if self.config.log_with == "wandb":
+            import wandb
+
+            if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]):
+                raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
+
+            batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
+            if self.is_distributed:
+                gathered_batch_list = []
+                for b in batch_list:
+                    flattened = gather_object(b)
+                    gathered_batch_list.append(flattened)
+                batch_list = gathered_batch_list
+
         # Log only if we are in the main process
         if self.accelerator.is_main_process:
             logs = {}
@@ -1336,20 +1353,6 @@ def log_stats(
                     "'response'. "
                 )
             elif self.config.log_with == "wandb":
-                import wandb
-
-                if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]):
-                    raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
-
-                batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
-                if self.is_distributed:
-                    self.accelerator.wait_for_everyone()
-                    gathered_batch_list = []
-                    for batch in batch_list:
-                        flattened = gather_object(batch)
-                        gathered_batch_list.append(flattened)
-                    batch_list = gathered_batch_list
-
                 table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
                 logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})
 

From 84c87502a8be6eb49286860ede38965e5a343208 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Thu, 4 Jan 2024 16:28:54 +0000
Subject: [PATCH 2/2] quick fix

---
 trl/trainer/ppo_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 01e8ed14d5..3194ff292c 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -201,7 +201,6 @@ def __init__(
             project_config=ProjectConfiguration(**config.project_kwargs),
             **config.accelerator_kwargs,
         )
-        # torch.cuda.set_device(self.accelerator.local_process_index)
 
         # Step 1.1 Runtime variables filled by the accelerator
         config.world_size = self.accelerator.num_processes
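
The first patch moves the per-column `gather_object` calls in `log_stats` above the `if self.accelerator.is_main_process:` check, so every rank joins the collective before only rank 0 builds the wandb table. Below is a minimal sketch of that gather-then-log pattern in isolation, assuming a standalone script launched with `accelerate launch`; the script name, sample data, and print call are illustrative and are not part of ppo_trainer.py or these patches.

    # demo_gather.py -- minimal sketch, not part of the patches above.
    # Run with e.g.: accelerate launch --num_processes 2 demo_gather.py
    from accelerate import Accelerator
    from accelerate.utils import gather_object

    accelerator = Accelerator()

    # Per-process data, analogous to batch["query"] / batch["response"] in log_stats.
    local_rows = [f"rank {accelerator.process_index} sample {i}" for i in range(2)]

    # Collective call: executed on every rank, outside any is_main_process guard.
    all_rows = gather_object(local_rows)

    if accelerator.is_main_process:
        # Only rank 0 logs, but it now sees the rows gathered from all ranks.
        print(all_rows)

If only rank 0 entered the collective (as in the code the patch removes from the main-process branch), the remaining ranks would never reach the matching gather call and the job could stall, which is why the gathering happens before the main-process check.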