
add mini r1 zero tutorial #14


Open · wants to merge 2 commits into base: master
12 changes: 12 additions & 0 deletions model-examples/README.md
@@ -29,3 +29,15 @@
### AIGC/Text-to-Audio

- (TODO) MusicGen introduction and inference implementation


### LLMs/Large Language Models

- (TODO) Overview of DeepSeek V3 key techniques


### Reasoning

- [Contents](./reasoning-models/README.md)

- [Mini DeepSeek R1 Zero: from zero to one in practice](./reasoning-models/mini-r1-zero/README.md)
14 changes: 14 additions & 0 deletions model-examples/reasoning-models/README.md
@@ -0,0 +1,14 @@
## Reasoning Tutorials

### Available

- [mini r1 zero](./mini-r1-zero/README.md): A minimal reproduction of [DeepSeek R1 Zero](https://github.com/deepseek-ai/DeepSeek-R1) with native MindSpore.

### TODO List

- (TODO) open-r1: a reproduction of [huggingface/open-r1](https://github.com/huggingface/open-r1)
- (TODO) simpleRL-reason: a reproduction of [hkust-nlp/simpleRL-reason](https://github.com/hkust-nlp/simpleRL-reason)
- (TODO) search-and-learn: a reproduction of [huggingface/search-and-learn](https://github.com/huggingface/search-and-learn)
- (TODO) deepscaler: a reproduction of [agentica-project/deepscaler](https://github.com/agentica-project/deepscaler)
- (TODO) OpenThinker: a reproduction of [open-thoughts/open-thoughts](https://github.com/open-thoughts/open-thoughts)
- (TODO) s1: a reproduction of [simplescaling/s1](https://github.com/simplescaling/s1) from Fei-Fei Li's team
59 changes: 59 additions & 0 deletions model-examples/reasoning-models/mini-r1-zero/README.md
@@ -0,0 +1,59 @@
# Mini R1-Zero with MindSpore

*A minimal reproduction of [DeepSeek R1 Zero](https://github.com/deepseek-ai/DeepSeek-R1) with native MindSpore.*


## Tutorials

See [train_a_mini_r1_zero_from_scratch.md](./tutorials-docs/train_a_mini_r1_zero_from_scratch.md)


## Features

- [x] GRPO RL method (see the loss summary below)
- [x] format reward and accuracy reward (Countdown)
- [x] Countdown game task
- [x] base model: Qwen2.5-1.5B-Instruct
- [x] trainable on Ascend* devices

- [ ] (TODO) large-scale training
- [ ] (TODO) evaluation
- [ ] (TODO) visualization of results
- [ ] (TODO) AI-MO math task and reward
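
The loss implemented in `src/grpo.py` (summarized here from the code in this PR) normalizes rewards within each group of $G$ sampled completions to obtain advantages and adds a per-token KL penalty against the frozen reference model:

$$
A_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G) + 10^{-4}}
$$

$$
\mathcal{L}_i = \frac{1}{\sum_t m_{i,t}} \sum_t m_{i,t} \left[ -\, e^{\log\pi_\theta(o_{i,t}) - \operatorname{sg}\left[\log\pi_\theta(o_{i,t})\right]} A_i + \beta\, \hat{D}_{\mathrm{KL}}\!\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]_{i,t} \right]
$$

where $m_{i,t}$ masks tokens after the first EOS, the KL term uses the estimator $e^{d} - d - 1$ with $d = \log\pi_{\mathrm{ref}} - \log\pi_\theta$, and the final loss averages $\mathcal{L}_i$ over all completions in the batch.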


## Installation

```shell
pip install git+https://github.com/zhanghuiyao/mindone.git@add_qwen2
```
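
To sanity-check the environment before training (a quick check, assuming the fork exposes the `mindone.transformers` MindSpore adapter that `src/dataset.py` imports):

```shell
python -c "import mindspore; from mindone.transformers.mindspore_adapter.data import HF2MSDataset; print(mindspore.__version__)"
```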


## Run Training

```shell
python train_r1_zero.py \
--model-path Qwen/Qwen2.5-1.5B-Instruct \
--dataset-path Jiayi-Pan/Countdown-Tasks-3to4 \
--max-completion-length 256 \
--bf16 \
--is-distribute False
```
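
For multi-card training, a launch along the following lines might work; this is a sketch, assuming `--is-distribute True` switches `train_r1_zero.py` into data-parallel mode and that MindSpore's `msrun` launcher is available. Adjust the worker counts to your Ascend setup.

```shell
# hypothetical multi-card launch; --is-distribute True is assumed to enable data parallelism
msrun --worker_num=8 --local_worker_num=8 --join=True train_r1_zero.py \
--model-path Qwen/Qwen2.5-1.5B-Instruct \
--dataset-path Jiayi-Pan/Countdown-Tasks-3to4 \
--max-completion-length 256 \
--bf16 \
--is-distribute True
```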


## Acknowledgements
* DeepSeek R1 [paper](https://arxiv.org/abs/2501.12948)
* DeepSeek Math [paper](https://arxiv.org/abs/2402.03300)
* We use the [Qwen2.5](https://github.com/QwenLM/Qwen2.5) series as the base model.


## Citation
```bibtex
@misc{mini-r1-zero-ms,
author = {mindspore-lab team},
title = {Mini R1-Zero with MindSpore},
howpublished = {https://github.com/mindspore-lab/tutorials/model-examples/reasoning-models/mini-r1-zero},
note = {Accessed: 2025-02-12},
year = {2025}
}
```
92 changes: 92 additions & 0 deletions model-examples/reasoning-models/mini-r1-zero/src/dataset.py
@@ -0,0 +1,92 @@
import numpy as np

from transformers import PreTrainedTokenizer
from datasets import load_dataset
from mindone.transformers.mindspore_adapter.data import HF2MSDataset

import mindspore


def create_countdown_dataset(
hf_dataset_path: str = "Jiayi-Pan/Countdown-Tasks-3to4",
tokenizer: PreTrainedTokenizer = None,
batch_size: int = 1,
num_epochs: int = 1,
rank: int = 0,
rank_size: int = 1,
):
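"""Build MindSpore train/test iterators for the Countdown task.

Loads the Hugging Face dataset, wraps each sample in an R1-style chat prompt,
takes a 50k-sample shuffled subset with a 90/10 train/test split, and returns
dict iterators whose batches hold the raw prompts, numbers, targets, tokenized
prompt ids and attention masks.
"""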

# generate the r1 prompt with a prefix so the model already starts with the thinking process

# TODO: remove <think>
def generate_r1_prompt(numbers, target):
r1_prefix = [{
"role": "system",
"content": "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer."
},
{
"role": "user",
"content": f"Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 = 1 </answer>."
},
{
"role": "assistant",
"content": "Let me solve this step by step.\n<think>"
}]
return {"prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True),
"target": target}

# Load dataset from Hugging Face Hub
dataset = load_dataset(hf_dataset_path, split="train")
# select a random subset of 50k samples
dataset = dataset.shuffle(seed=42).select(range(50000))

# convert our dataset to the r1 prompt
dataset = dataset.map(lambda x: generate_r1_prompt(x["nums"], x["target"]))

# split the dataset into train and test
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

# convert hf-datasets to mindspore-dataset

def ms_data_collator(inputs, batch_info):
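# Collate a list of HF samples into numpy batches; prompts are tokenized to a
# fixed length of 256 (right padding) and unpadded again in
# GRPO.get_completion_and_reward before generation.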
first = inputs[0]
assert isinstance(first, dict)
prompts = [x["prompt"] for x in inputs]
nums = [x["nums"] for x in inputs]
targets = np.array([int(x["target"]) for x in inputs])
# prompt_inputs = tokenizer(prompts, return_tensors="np", padding=True, padding_side="left", add_special_tokens=False)
prompt_inputs = tokenizer(
prompts,
return_tensors="np",
padding="max_length",
truncation=True,
max_length=256,
padding_side="right",
add_special_tokens=False
)
batch = {
"prompts": prompts,
"nums": nums,
"targets": targets,
"prompt_ids": prompt_inputs.input_ids,
"attention_mask": prompt_inputs.attention_mask,
}
return batch

ms_train_dataset = mindspore.dataset.GeneratorDataset(
HF2MSDataset(train_dataset), column_names="item", shard_id=rank, num_shards=rank_size
)
ms_train_dataset = ms_train_dataset.batch(batch_size=batch_size, per_batch_map=ms_data_collator)
ms_train_dataset = ms_train_dataset.repeat(1)
ms_train_dataset = ms_train_dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)

ms_test_dataset = mindspore.dataset.GeneratorDataset(
HF2MSDataset(test_dataset), column_names="item", shard_id=rank, num_shards=rank_size
)
ms_test_dataset = ms_test_dataset.batch(batch_size=1, per_batch_map=ms_data_collator)
ms_test_dataset = ms_test_dataset.repeat(1)
ms_test_dataset = ms_test_dataset.create_dict_iterator(num_epochs=1, output_numpy=True)

return ms_train_dataset, ms_test_dataset
144 changes: 144 additions & 0 deletions model-examples/reasoning-models/mini-r1-zero/src/grpo.py
@@ -0,0 +1,144 @@
import numpy as np
from typing import Callable, Optional, List
from transformers import PreTrainedTokenizer, GenerationConfig

import mindspore as ms
from mindspore import nn, ops, Tensor


class GRPO(nn.Cell):
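"""Minimal GRPO training cell.

Holds a trainable policy model, a frozen reference model and a list of reward
functions; generates G completions per prompt, scores them with the rewards,
and computes the GRPO loss (group-normalized advantages plus a KL penalty
against the reference model).
"""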
def __init__(
self,
policy_model: Optional[nn.Cell],
reference_model: Optional[nn.Cell],
reward_funcs: List[Callable],
tokenizer: PreTrainedTokenizer,
args
):
super(GRPO, self).__init__()

self.policy_model = policy_model
self.reference_model = reference_model
self.reward_funcs = reward_funcs
self.tokenizer = tokenizer

# Training arguments
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.beta = args.beta

self.generation_config = GenerationConfig(
max_new_tokens=args.max_completion_length,
do_sample=True,
temperature=args.temperature,
num_return_sequences=args.num_generations,
pad_token_id=tokenizer.pad_token_id,
)

def set_policy_train(self):

# 0. setting models' training mode
self.policy_model.set_train(True)
self.reference_model.set_train(False)
for rm in self.reward_funcs:
if isinstance(rm, nn.Cell):
rm.set_train(False)

def get_completion_and_reward(self, batch):
prompts, nums, targets, prompt_ids, attention_mask = \
batch["prompts"], batch["nums"], batch["targets"], batch["prompt_ids"], batch["attention_mask"]
prompts = np.array(prompts).tolist()
nums = [np.array(b).tolist() for b in batch["nums"]]

# FIXME: unpad
assert prompt_ids.shape[0] == 1, "batch size > 1 is not supported in the generation step"
prompt_ids = prompt_ids[:, :attention_mask.sum()]
attention_mask = attention_mask[:, :attention_mask.sum()]

completion_ids = self.policy_model.generate(
input_ids=Tensor(prompt_ids, ms.int32),
attention_mask=Tensor(attention_mask, ms.bool_),
generation_config=self.generation_config,
max_new_tokens=self.max_completion_length,
use_cache=False,
)
completion_ids = completion_ids.asnumpy()
prompt_completion_ids = np.concatenate([prompt_ids.repeat(self.num_generations, axis=0), completion_ids], axis=-1)
num_logits_to_keep = completion_ids.shape[1]

# Mask everything after the first EOS token
is_eos = np.array(completion_ids == self.tokenizer.eos_token_id)
eos_idx = np.full((is_eos.shape[0],), is_eos.shape[1], dtype=np.int32)
eos_idx[is_eos.any(axis=1)] = is_eos.astype(np.int32).argmax(axis=1)[is_eos.any(axis=1)]
sequence_indices = np.arange(is_eos.shape[1])[None].repeat(is_eos.shape[0], axis=0)
completion_mask = (sequence_indices <= eos_idx[:, None]).astype(np.int32)

# get reward
# decode the generated completions
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]

rewards_per_func = np.zeros((len(prompts), len(self.reward_funcs)), dtype=np.float32)
for i, reward_func in enumerate(self.reward_funcs):
rewards_per_func[:, i] = reward_func(completions=completions, nums=nums, targets=targets, prompt=prompts) # Shape (B*G,)
rewards = rewards_per_func.sum(axis=1)

return prompt_completion_ids, num_logits_to_keep, completion_mask, rewards

def compute_loss(
self,
prompt_completion_ids: Tensor,
num_logits_to_keep: int,
completion_mask: Tensor,
rewards: Tensor
) -> Tensor:

# 1. compute grpo reward and advantages
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(axis=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(axis=1)

# normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# 2. compute kl divergence
per_token_logps = self.get_log_probabilities(
prompt_completion_ids,
self.policy_model(prompt_completion_ids)[0],
num_logits_to_keep,
)
ref_per_token_logps = self.get_log_probabilities(
prompt_completion_ids,
self.reference_model(prompt_completion_ids)[0],
num_logits_to_keep,
)
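# k3 estimator of KL(policy || reference): exp(d) - d - 1 with d = ref_logp - policy_logp (unbiased and non-negative)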
kl_divergence = ops.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

# exp(x - stop_gradient(x)) equals 1 in the forward pass but keeps gradients flowing through x (the "x - x.detach()" trick)
per_token_loss = ops.exp(per_token_logps - ops.stop_gradient(per_token_logps)) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * kl_divergence)
loss = ((per_token_loss * completion_mask).sum(axis=1) / completion_mask.sum(axis=1)).mean()

return loss

@ms.jit
def get_log_probabilities(self, input_ids: Tensor, logits: Tensor, num_logits_to_keep: int) -> Tensor:
# logits: (B, L, V), prompt_completion_logits -> completion_logits
logits = logits[:, -(num_logits_to_keep+1):-1, :]
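# logits at position t predict token t+1, so drop the last position and keep exactly the positions that predict the completion tokens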

# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = ()

for i in range(logits.shape[0]):
logits_row, input_ids_row = logits[i], input_ids[i, -num_logits_to_keep:]

log_probs = ops.log_softmax(logits_row, axis=-1)
token_log_prob = ops.gather_elements(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps += (token_log_prob,)

return ops.stack(per_token_logps)




62 changes: 62 additions & 0 deletions model-examples/reasoning-models/mini-r1-zero/src/rewards.py
@@ -0,0 +1,62 @@
# Adapted from:
# https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
# https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb
import re

import numpy as np


def format_reward(completions: list[str], *args, **kwargs) -> np.ndarray:
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
# prepend a synthetic <think>, since it is already part of the prompt and prefilled for the assistant, so the completion matches the regex more easily
matches = [re.match(pattern, "<think>" + content, re.DOTALL | re.MULTILINE) for content in completions]
return np.array([1.0 if match else 0.0 for match in matches])


def countdown_game_accuracy_reward(completions: list[str], nums: list[list[int]], targets: np.ndarray, *args, **kwargs) -> np.ndarray:
"""
For the Countdown game, evaluates completions based on the mathematical correctness of the answer.

Args:
completions (list[str]): Generated outputs
targets (np.ndarray): Expected answers, one per sample
nums (list[list[int]]): Available numbers for each sample

Returns:
np.ndarray: Reward scores
"""
rewards = []
for completion, gt, numbers in zip(completions, targets, nums):
try:
# Check if the format is correct
match = re.search(r"<answer>(.*?)<\/answer>", completion)
if match is None:
rewards.append(0.0)
continue
# Extract the "answer" part from the completion
equation = match.group(1).strip()
# Extract all numbers from the equation
used_numbers = [int(n) for n in re.findall(r'\d+', equation)]

# Check if all numbers are used exactly once
if sorted(used_numbers) != sorted(numbers):
rewards.append(0.0)
continue
# Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
allowed_pattern = r'^[\d+\-*/().\s]+$'
if not re.match(allowed_pattern, equation):
rewards.append(0.0)
continue

# Evaluate the equation with restricted globals and locals
result = eval(equation, {"__builtins__": None}, {})
# Check if the equation is correct and matches the ground truth
if abs(float(result) - float(gt)) < 1e-5:
rewards.append(1.0)
else:
rewards.append(0.0)
except Exception:
# If evaluation fails, reward is 0
rewards.append(0.0)
return np.array(rewards)