Enable Weights & Biases defaults to be overridden in training #294

Merged · 2 commits · Feb 12, 2025
12 changes: 10 additions & 2 deletions README.md
@@ -126,6 +126,14 @@
--per_device_train_batch_size=1 --num_train_epochs=5
```

If you also wish to override the Weights & Biases default settings, you can do so as follows:

```shell
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
```
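
These flags are exported as the standard `WANDB_ENTITY` and `WANDB_PROJECT` environment variables (see `src/open_r1/utils/logging.py` below), so setting those variables directly should be equivalent:

```shell
# Equivalent form: wandb reads WANDB_ENTITY and WANDB_PROJECT at init time.
WANDB_ENTITY=huggingface WANDB_PROJECT=open-r1 \
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
```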

> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.

@@ -141,10 +149,10 @@

### GRPO

To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then override `num_processes` to run on 7 devices:
To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then override `num_processes` to run on 7 devices:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
```
16 changes: 16 additions & 0 deletions src/open_r1/configs.py
@@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)


@dataclass
@@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
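
A quick way to sanity-check the new fields (a minimal sketch, assuming the same `TrlParser` wiring the training scripts below use; the `print` is purely illustrative):

```python
from open_r1.configs import SFTConfig
from trl import ModelConfig, ScriptArguments, TrlParser

# The two new dataclass fields surface automatically as --wandb_entity and
# --wandb_project command-line flags when parsed this way.
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
print(training_args.wandb_entity, training_args.wandb_project)  # None unless overridden
```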
6 changes: 5 additions & 1 deletion src/open_r1/grpo.py
@@ -33,6 +33,7 @@
reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config


@@ -130,7 +131,7 @@ def main(script_args, training_args, model_args):
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
logger.info(f"Training parameters {training_args}")

# Check for last checkpoint
last_checkpoint = None
@@ -139,6 +140,9 @@
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

if "wandb" in training_args.report_to:
init_wandb_training(training_args)

# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

6 changes: 5 additions & 1 deletion src/open_r1/sft.py
@@ -48,6 +48,7 @@

from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import (
ModelConfig,
ScriptArguments,
@@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
logger.info(f"Training parameters {training_args}")

# Check for last checkpoint
last_checkpoint = None
@@ -97,6 +98,9 @@
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

if "wandb" in training_args.report_to:
init_wandb_training(training_args)

################
# Load datasets
################
9 changes: 9 additions & 0 deletions src/open_r1/utils/logging.py
@@ -0,0 +1,9 @@
import os


def init_wandb_training(training_args):
    """
    Helper function for setting up Weights & Biases logging tools.
    """
    # Guard against the None defaults: assigning None to os.environ raises a
    # TypeError, and users may report to wandb without overriding either value.
    if training_args.wandb_entity is not None:
        os.environ["WANDB_ENTITY"] = training_args.wandb_entity
    if training_args.wandb_project is not None:
        os.environ["WANDB_PROJECT"] = training_args.wandb_project