From 177e6c4ee27ed7612dd24fc4b656cab5f2a99053 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Wed, 12 Feb 2025 11:56:43 +0000
Subject: [PATCH 1/2] Enable WandB defaults to be set

---
 README.md                    | 10 +++++++++-
 src/open_r1/configs.py       | 16 ++++++++++++++++
 src/open_r1/grpo.py          |  6 +++++-
 src/open_r1/sft.py           |  6 +++++-
 src/open_r1/utils/logging.py |  9 +++++++++
 5 files changed, 44 insertions(+), 3 deletions(-)
 create mode 100644 src/open_r1/utils/logging.py

diff --git a/README.md b/README.md
index 48b767b5..2579d776 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,14 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
     --per_device_train_batch_size=1 --num_train_epochs=5
 ```

+If you also wish to override the Weights and Biases default settings, you can do so as follows:
+
+```shell
+accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
+    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
+    --wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
+```
+
 > [!NOTE]
 > The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.

@@ -144,7 +152,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
 To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, one a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then overwrite `num_processes` to run on 7 devices:

 ```shell
-ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
     --num_processes=7 src/open_r1/grpo.py \
     --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
 ```
diff --git a/src/open_r1/configs.py b/src/open_r1/configs.py
index 57968b4b..3a6f6866 100644
--- a/src/open_r1/configs.py
+++ b/src/open_r1/configs.py
@@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The entity to store runs under.")},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The project to store runs under.")},
+    )


 @dataclass
@@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The entity to store runs under.")},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The project to store runs under.")},
+    )
diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py
index 5cb64552..128375db 100644
--- a/src/open_r1/grpo.py
+++ b/src/open_r1/grpo.py
@@ -33,6 +33,7 @@
     reasoning_steps_reward,
 )
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.logging import init_wandb_training
 from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config


@@ -130,7 +131,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")

     # Check for last checkpoint
     last_checkpoint = None
@@ -139,6 +140,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

diff --git a/src/open_r1/sft.py b/src/open_r1/sft.py
index e8587d03..16791cd4 100644
--- a/src/open_r1/sft.py
+++ b/src/open_r1/sft.py
@@ -48,6 +48,7 @@

 from open_r1.configs import SFTConfig
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.logging import init_wandb_training
 from trl import (
     ModelConfig,
     ScriptArguments,
@@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")

     # Check for last checkpoint
     last_checkpoint = None
@@ -97,6 +98,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
 ################
 # Load datasets
 ################
diff --git a/src/open_r1/utils/logging.py b/src/open_r1/utils/logging.py
new file mode 100644
index 00000000..764f30f8
--- /dev/null
+++ b/src/open_r1/utils/logging.py
@@ -0,0 +1,9 @@
+import os
+
+
+def init_wandb_training(training_args):
+    """
+    Helper function for setting up Weights & Biases logging tools.
+    """
+    os.environ["WANDB_ENTITY"] = training_args.wandb_entity
+    os.environ["WANDB_PROJECT"] = training_args.wandb_project

From a82bb85a439da66b58bf0a4ef1b8a1dbb82eb391 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Wed, 12 Feb 2025 11:57:59 +0000
Subject: [PATCH 2/2] Fix

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 2579d776..333c54f8 100644
--- a/README.md
+++ b/README.md
@@ -149,7 +149,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con

 ### GRPO

-To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, one a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then overwrite `num_processes` to run on 7 devices:
+To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, one a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then overwrite `num_processes` to run on 7 devices:

 ```shell
 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \