
grpo liger loss #3781

Draft · wants to merge 16 commits into main

2 changes: 2 additions & 0 deletions requirements/install_all.sh
@@ -9,4 +9,6 @@ pip install timm -U
pip install deepspeed -U
pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile -U
pip install liger_kernel nvitop pre-commit -U
pip install wandb
pip install math_verify==0.5.2
# flash-attn: https://github.com/Dao-AILab/flash-attention/releases
9 changes: 7 additions & 2 deletions swift/llm/argument/rlhf_args.py
@@ -5,7 +5,7 @@

from swift.llm import MODEL_MAPPING
from swift.trainers.arguments import GRPOArgumentsMixin
from swift.utils import get_logger, is_master, set_default_ddp_config
from swift.utils import get_logger, is_liger_available, is_master, set_default_ddp_config
from .train_args import TrainArguments

logger = get_logger()
@@ -211,13 +211,18 @@ def _check_rlhf(self):
def _check_grpo(self):
if self.rlhf_type != 'grpo':
return

from packaging import version

import trl
trl_version = version.parse(trl.__version__)
assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
'Please update it by running: pip install -U trl')

if self.use_liger_loss:
from trl.import_utils import is_liger_kernel_available
assert is_liger_kernel_available(), (
'Please install/update liger-kernel by running: pip install -U liger-kernel')

if self.num_generations < 2:
raise ValueError(
'GRPO requires at least 2 generations per prompt to calculate the advantages. You provided '
3 changes: 2 additions & 1 deletion swift/llm/template/base.py
@@ -1130,7 +1130,8 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs):
old_kwargs = to_device(kwargs, model.device)
kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
for k, v in old_kwargs.items():
if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
if k in {'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states'
} and k not in kwargs:
kwargs[k] = v
if 'inputs_embeds' in kwargs:
kwargs.pop('input_ids', None)
2 changes: 2 additions & 0 deletions swift/trainers/arguments.py
@@ -200,6 +200,8 @@ class GRPOArgumentsMixin:
# dataset
dataset_shuffle: Optional[bool] = True

use_liger_loss: bool = False


@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
2 changes: 1 addition & 1 deletion swift/trainers/rlhf_trainer/__init__.py
@@ -12,7 +12,7 @@
from .ppo_trainer import PPOTrainer
from .reward_trainer import RewardTrainer
from .rlhf_mixin import RLHFTrainerMixin
from .utils import _split_into_mini_batches, patch_lora_merge, patch_lora_unmerge, round_robin
from .utils import _split_into_mini_batches, patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection
else:
_import_structure = {
'cpo_trainer': ['CPOTrainer'],
136 changes: 117 additions & 19 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -38,10 +38,11 @@
from swift.plugin import orms
from swift.plugin.multi_turn import multi_turns
from swift.utils import (JsonlWriter, gc_collect, get_device, get_device_count, get_dist_setting, get_logger,
get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available)
get_node_setting, is_liger_available, is_lmdeploy_available, is_vllm_available,
is_wandb_available)
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin
from .utils import patch_lora_merge, patch_lora_unmerge, round_robin
from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, round_robin

del HFGRPOTrainer.__init__
del HFGRPOTrainer.log
@@ -165,7 +166,8 @@ def __init__(self,
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

self.num_generations = args.num_generations
self.num_generations = args.num_generations # = G in the GRPO paper
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.temperature = args.temperature
self.loss_type = args.loss_type
model.warnings_issued['estimate_tokens'] = True
@@ -174,7 +176,14 @@ def __init__(self,

use_vllm = args.use_vllm
use_lmdeploy = args.use_lmdeploy

# we initialize vllm_client in RLHFArguments._init_external_vllm (swift/llm/rlhf_args)
vllm_client = kwargs.pop('vllm_client') # for external vllm
self.use_vllm = args.use_vllm
self.use_lmdeploy = args.use_lmdeploy
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
self.log_completions = args.log_completions
if self.args.tensor_parallel_size > 1 and self.multi_turn_func:
import torch.distributed as dist
rank, _, _, _ = get_dist_setting()
@@ -183,8 +192,29 @@ def __init__(self,
if rank in tp_group:
self.group = group

model.warnings_issued['estimate_tokens'] = True
kwargs['data_collator'] = lambda features: features

super().__init__(model, ref_model, *_args, **kwargs)

self.use_liger_loss = self.args.use_liger_loss
if self.use_liger_loss:
if not is_liger_available():
raise ImportError(
'Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`.')
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
beta=self.beta,
epsilon_low=self.epsilon_low,
epsilon_high=self.epsilon_high,
temperature=self.temperature,
use_ref_model=self.beta != 0.0,
loss_type=self.loss_type,
max_completion_length=self.max_completion_length,
)
self._forward_redirection = _ForwardRedirection()

self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
self.log_completions = args.log_completions
self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
@@ -252,7 +282,7 @@ def __init__(self,
'reducing it by one is sufficient. '
f'In your case: `--num_processes {get_device_count() - 1}`.')

if use_vllm:
if self.use_vllm:
if not is_vllm_available():
raise ImportError('vLLM is not available and `use_vllm` is set to True. '
'Please install vLLM with `pip install vllm -U` to use it.')
@@ -261,7 +291,7 @@ def __init__(self,
else:
self.engine = self.prepare_vllm(model, fast_infer_device)
self.infer_device = fast_infer_device[self.local_infer_rank]
elif use_lmdeploy:
elif self.use_lmdeploy:
if not is_lmdeploy_available():
raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.'
'Please install LMDeploy with `pip install lmdeploy -U` to use it.')
@@ -317,8 +347,6 @@ def __init__(self,

# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa
self._step = 0
@@ -580,7 +608,7 @@ def _move_model_to_vllm_lmdeploy(self):
if self.infer_rank >= 0:
if self.args.async_generate:
self._wait_queue()
if self.args.use_vllm:
if self.use_vllm:
llm_model = self.engine.inner_model
else:
llm_model = self.engine.engine.engine
@@ -593,7 +621,7 @@ def _move_model_to_vllm_lmdeploy(self):
with patch_lora_unmerge(unwrapped_model):
unwrapped_model.unmerge_adapter()

if self.infer_rank >= 0 and self.args.use_vllm and self.args.vllm_enable_prefix_caching:
if self.infer_rank >= 0 and self.use_vllm and self.args.vllm_enable_prefix_caching:
self.engine.engine.reset_prefix_cache()

def _wait_queue(self):
@@ -1056,15 +1084,6 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li
batch_encoded_inputs['old_per_token_logps'] = (
self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None)

if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs)
batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps

ga_batch_encoded_inputs.append(batch_encoded_inputs)

return ga_batch_encoded_inputs
@@ -1125,6 +1144,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
if isinstance(inputs, list):
assert len(inputs) == 1
inputs = inputs[0]
if self.use_liger_loss:
unwrapped_model = self.accelerator.unwrap_model(model)
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
else:
return self._compute_loss(model, inputs)

def _compute_loss(self, model, inputs):
completion_mask = inputs['completion_mask']
truncated_mask = inputs['truncated_mask']
# apply the completion_mask to exclude loss and metrics for overlong completions
@@ -1138,7 +1164,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

# Compute the KL divergence between the model and the reference model
if self.beta != 0.0:
ref_per_token_logps = inputs['ref_per_token_logps']
with torch.no_grad():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, inputs)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(self.model, inputs)

per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1)

@@ -1214,6 +1246,72 @@ def _get_per_token_logps(self, model, inputs):
input_ids = input_ids[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

@profiling_decorator
def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
# unwrap the model to access the model.model
if is_peft_model(unwrapped_model):
unwrapped_model = unwrapped_model.base_model.model
if not unwrapped_model.model_meta.is_multimodal:
last_hidden_state = unwrapped_model.model(
input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).last_hidden_state
else:
inputs = {
k: v
for k, v in inputs.items() if k not in [
'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
'truncated_mask'
]
}
with self._template_context(self.template):
outputs = unwrapped_model(**inputs, output_hidden_states=True)
last_hidden_state = outputs.hidden_states[-1]

last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H)
if logits_to_keep is not None:
last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
return last_hidden_state

def compute_liger_loss(self, unwrapped_model, inputs):
# Compute the per-token log probabilities for the model
input_ids = inputs['input_ids']
logits_to_keep = inputs['logits_to_keep']
completion_ids = input_ids[:, -logits_to_keep:]
completion_mask = inputs['completion_mask']

# Compute the KL divergence between the model and the reference model
ref_per_token_logps = None
if self.beta != 0.0:
with torch.no_grad():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(self.ref_model, inputs)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(self.model, inputs)

# get the last hidden state of the model
last_hidden_state = self._get_last_hidden_state(unwrapped_model, inputs, logits_to_keep)
# compute loss and metrics using liger grpo loss
loss, metrics = self.liger_grpo_loss(
_input=last_hidden_state,
lin_weight=unwrapped_model.lm_head.weight,
selected_token_ids=completion_ids,
attention_mask=completion_mask,
advantages=inputs['advantages'],
bias=unwrapped_model.lm_head.bias,
old_per_token_logps=inputs['old_per_token_logps'],
ref_per_token_logps=ref_per_token_logps,
)
# Extract metrics from the liger_grpo_loss output
# KL divergence is the first metric when beta is non-zero
mean_kl = metrics[0] if self.beta != 0.0 else None
clip_ratio = metrics[-1]

mode = 'eval' if self.control.should_evaluate else 'train'
if self.beta != 0.0:
self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
self._metrics[mode]['clip_ratio'].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
return loss

def evaluation_loop(self, dataloader, *args, **kwargs):
# Wait for the training rollout to complete
if self.args.async_generate:
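
For reference, the liger path above avoids materializing the full `(B, L, vocab)` logits tensor: `compute_liger_loss` hands the sliced last hidden states and the `lm_head` weight to `LigerFusedLinearGRPOLoss`, which computes the per-token log-probs and the clipped GRPO objective in chunks. Below is a minimal standalone sketch of that call, assuming the `liger_kernel` API exactly as invoked in this diff; the toy tensor sizes and `loss_type='bnpo'` are illustrative assumptions, not part of the PR.

```python
# Standalone sketch (not part of the PR) of the fused GRPO loss call used above.
# Assumes the liger_kernel API as invoked in compute_liger_loss; sizes are toy values.
import torch
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

B, T, H, V = 2, 8, 16, 32                                 # batch, completion len, hidden, vocab
last_hidden = torch.randn(B, T, H, requires_grad=True)    # hidden states sliced to the completion tokens
lm_head_weight = torch.randn(V, H, requires_grad=True)
completion_ids = torch.randint(0, V, (B, T))
completion_mask = torch.ones(B, T, dtype=torch.long)
advantages = torch.randn(B)                               # one advantage per completion

liger_grpo_loss = LigerFusedLinearGRPOLoss(
    beta=0.0,                 # beta == 0 -> no KL term, so no reference log-probs needed
    epsilon_low=0.2,
    epsilon_high=0.2,
    temperature=1.0,
    use_ref_model=False,
    loss_type='bnpo',         # assumed value; the trainer passes args.loss_type
    max_completion_length=T,
)
loss, metrics = liger_grpo_loss(
    _input=last_hidden,
    lin_weight=lm_head_weight,
    selected_token_ids=completion_ids,
    attention_mask=completion_mask,
    advantages=advantages,
    bias=None,
    old_per_token_logps=None,   # on-policy case: current log-probs are reused
    ref_per_token_logps=None,
)
loss.backward()                 # gradients flow to last_hidden and lm_head_weight
```
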
46 changes: 46 additions & 0 deletions swift/trainers/rlhf_trainer/utils.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import torch
from peft.tuners import lora
from peft.tuners.lora import LoraLayer
from torch import nn


def round_robin(num_reqs, num_workers):
@@ -130,3 +131,48 @@ def unmerge_patched(self):
del module.unmerge_origin
module._cache_pop = module._cache_pop_origin
del module._cache_pop_origin


class _ForwardRedirection:
"""Implements the `forward-redirection`.
Taken from Pytorch-lightning:
https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
"""

def __call__(self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any,
**kwargs: Any):
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
Args:
wrapper_module: The module that has `original_module` wrapped.
original_module: The module that was wrapped inside `wrapper_module`.
method: The method that should be called on the `original_module` after inputs get
redirected through the `wrapper_module`'s `forward` method.
*args: The positional arguments to `method`. They will get passed to a patched
`forward` method instead.
**kwargs: The keyword arguments to `method`. They will get passed to a patched
`forward` method instead.
"""
original_forward = original_module.forward

def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
# Unpatch ourselves immediately before calling `method`, because the method
# itself may want to call the real `forward`.
original_module.forward = original_forward # type: ignore[method-assign]
# Call the actual method, e.g. `compute_liger_loss(...)`
out = method(*_args, **_kwargs)
self.on_after_inner_forward(wrapper_module, original_module)
return out

# Patch the original_module's forward so we can redirect the arguments back to the real method
original_module.forward = wrapped_forward # type: ignore[method-assign]

wrapper_output = wrapper_module(*args, **kwargs)
self.on_after_outer_forward(wrapper_module, original_module)
return wrapper_output

def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass

def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
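
`_ForwardRedirection` exists so that `compute_liger_loss` can run against the unwrapped model while the call still enters through the DeepSpeed/DDP wrapper's `forward`, keeping the wrapper's hooks and precision/ZeRO bookkeeping active. Below is a toy illustration of the mechanism; `Inner` and `Wrapper` are hypothetical stand-ins (not from the PR), and the import path assumes the module location added by this diff.

```python
# Toy demonstration (not from the PR) of how _ForwardRedirection reroutes a
# method call through a wrapper's forward.
import torch
from torch import nn

from swift.trainers.rlhf_trainer.utils import _ForwardRedirection  # path added by this PR


class Inner(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def hidden_twice(self, x):
        # a non-forward method we want to call *through* the wrapper
        return self.forward(x) * 2


class Wrapper(nn.Module):
    """Stands in for a DDP/DeepSpeed engine wrapping `Inner`."""

    def __init__(self, inner):
        super().__init__()
        self.module = inner

    def forward(self, *args, **kwargs):
        # real wrappers add hooks / communication here before delegating
        return self.module(*args, **kwargs)


inner = Inner()
wrapper = Wrapper(inner)
redirect = _ForwardRedirection()

x = torch.randn(2, 4)
# Equivalent to inner.hidden_twice(x), but the call enters through wrapper(...),
# so any forward hooks installed on the wrapper still fire.
out = redirect(wrapper, inner, inner.hidden_twice, x)
print(out.shape)  # torch.Size([2, 4])
```
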