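"""Train a policy model with tree-based GRPO on Modal.

Builds a CUDA-enabled Modal image, pulls Hugging Face and Weights & Biases
credentials from Modal secrets, and fine-tunes rawsh/MetaMath-Qwen2.5-0.5b on
the MATH dataset with the GRPOTrainer from the local
reinforcement_learning.tree_grpo module.

Launch with the Modal CLI: `modal run modal_tree_grpo.py`
"""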
import modal
import sys
import traceback
# Define CUDA specifications
cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

# Create Modal image with all necessary dependencies
image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install(
        "torch",
        "transformers",
        "accelerate",
        "datasets",
        "wandb",
        "trl>=0.7.6",
        "huggingface_hub",
        "bitsandbytes",
    )
)
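# NOTE (assumption): reinforcement_learning.tree_grpo is a local package that
# is not installed into the image above, so it must also be shipped into the
# container for the import below to resolve. A minimal sketch, assuming the
# package lives next to this script; the mount would then be passed to
# @app.function via mounts=[tree_grpo_mount]:
#
#     tree_grpo_mount = modal.Mount.from_local_python_packages("reinforcement_learning")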
# Defer trainer imports so they only execute inside the container
with image.imports():
    from reinforcement_learning.tree_grpo import GRPOConfig, GRPOTrainer
# Create Modal app
app = modal.App("train-policy-tree-grpo", image=image)

@app.function(
    cpu=4.0,
    # gpu=modal.gpu.H100(count=1),
    gpu=modal.gpu.A10G(),
    timeout=24 * 60 * 60,  # allow up to 24 hours for the training run
    # memory=32768,
    secrets=[
        # Must expose HF_TOKEN and WANDB_API_KEY as env vars, respectively
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token"),
    ],
)
def train_policy_grpo():
    import os

    import wandb
    from datasets import load_dataset
    from huggingface_hub import HfFolder

    try:
        # Set up HuggingFace token
        hf_token = os.environ["HF_TOKEN"]
        HfFolder.save_token(hf_token)

        # Set up Weights & Biases
        wandb.login(key=os.environ["WANDB_API_KEY"])

        # Configuration
        config = GRPOConfig(
            exp_name="math_improvement",
            reward_model_path="rawsh/MetaMath-Qwen2.5-0.5b-PRM",
            num_grpo_epochs=4,
            sampling_group_size=8,  # completions sampled per prompt for the group baseline
            sampling_strategy="top_p",
            sampling_temperature=0.7,
            # learning_rate=1e-5,
            # num_train_epochs=3,
            # per_device_train_batch_size=4,
            # gradient_accumulation_steps=4,
            # output_dir="./grpo_math_model",
            # report_to=["wandb"]
        )

        # Initialize wandb
        wandb.init(
            project="grpo_math",
            name=config.exp_name,
            config=vars(config),
        )

        # Load dataset
        train_dataset = load_dataset("lighteval/MATH", "all", split="train")
        eval_dataset = load_dataset("lighteval/MATH", "all", split="test")

        # Create trainer
        trainer = GRPOTrainer.from_pretrained(
            config=config,
            pretrained_model_name_or_path="rawsh/MetaMath-Qwen2.5-0.5b",
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

        # Train
        trainer.train()

        # Save final model
        trainer.save_model()

        # Close wandb
        wandb.finish()
    except Exception as e:
        print(f"Error during training: {e}", file=sys.stderr)
        print("Traceback:", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        # Make sure to finish the wandb run even on error
        try:
            wandb.finish()
        except Exception:
            pass
        # Re-raise without rebinding to preserve the original traceback
        raise
@app.local_entrypoint()
def main():
    print("Starting full model GRPO training on Modal...")
    try:
        # .remote() blocks until the remote function returns
        train_policy_grpo.remote()
        print("Training run completed. Check the W&B dashboard for full logs.")
    except Exception as e:
        print(f"Error in training job: {e}", file=sys.stderr)
        sys.exit(1)
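# The training function expects two Modal secrets exposing HF_TOKEN and
# WANDB_API_KEY. A sketch of how they might be created with the Modal CLI
# (names match the from_name() calls above; values are placeholders):
#
#     modal secret create hf-token HF_TOKEN=hf_xxx
#     modal secret create wandb-token WANDB_API_KEY=xxx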