diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index 4cf8dd0140..8be89a9e2b 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -80,6 +80,7 @@ Loss :toctree: generated/ :nosignatures: + loss.PPOLoss loss.DPOLoss loss.RSOLoss loss.IPOLoss @@ -98,3 +99,17 @@ Functions used for preprocessing images. transforms.tile_crop transforms.find_supported_resolutions transforms.VisionCrossAttentionMask + +Reinforcement Learning From Human Feedback (RLHF) +-------------------------------------------------- +Components for RLHF algorithms like PPO. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + rlhf.estimate_advantages + rlhf.get_rewards_ppo + rlhf.truncate_sequence_at_first_stop_token + rlhf.left_padded_collate + rlhf.padded_collate_dpo diff --git a/docs/source/api_ref_utilities.rst b/docs/source/api_ref_utilities.rst index ca0b19a30b..4d64b9bf26 100644 --- a/docs/source/api_ref_utilities.rst +++ b/docs/source/api_ref_utilities.rst @@ -115,7 +115,6 @@ Utilities for working with data and datasets. :nosignatures: padded_collate - padded_collate_dpo .. _gen_label: diff --git a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml new file mode 100644 index 0000000000..7ccf510dff --- /dev/null +++ b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml @@ -0,0 +1,180 @@ +# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py +# using a Mistral 7B model. +# +# This config has been tested on an A100 80GB. +# This config uses hyperparameters based on small set of experiments and information +# available from existing implementations. +# +# This config assumes that you've run the following command before launching +# this run: +# tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/ --ignore-patterns="" +# tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --hf-token HF_TOKEN +# +# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders. +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory +# +# You can add specific overrides through the command line. 
For example +# to override the checkpointer directory while launching training +# you can run: +# tune run ppo_full_finetune_single_device --config mistral/7B_full_low_memory checkpointer.checkpoint_dir= +# + +# Tokenizer +tokenizer: + _component_: torchtune.models.mistral.mistral_tokenizer + path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.text_completion_dataset + source: trl-internal-testing/sentiment-trl-style + max_seq_len: null + split: train + column: prompt + add_eos: False + +policy_model: + _component_: torchtune.models.mistral.mistral_7b + +# we need to manually build the mistral classifier model +# because our reward model checkpoint has a larger vocabulary size (due to an added padding token) +reward_and_value_model: + _component_: torchtune.models.mistral._component_builders.mistral_classifier + attn_dropout: 0.0 + embed_dim: 4096 + intermediate_dim: 14336 + max_seq_len: 32768 + norm_eps: 1.0e-05 + num_classes: 1 + num_heads: 32 + num_kv_heads: 8 + num_layers: 32 + vocab_size: 32001 + +# checkpointer for the policy model - update this if resuming from checkpoint +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/ + checkpoint_files: [ + "pytorch_model-00001-of-00003.bin", + "pytorch_model-00002-of-00003.bin", + "pytorch_model-00003-of-00003.bin" + ] + # this is the only place where you should update `recipe_checkpoint` if resuming training + recipe_checkpoint: null + output_dir: ${output_dir}/policy + model_type: MISTRAL + +# this should be setup identically to the policy model checkpointer at the start of training +# ensure `checkpoint_files` always points to the original policy weights, even if resuming training +ref_policy_checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/ + checkpoint_files: [ + "pytorch_model-00001-of-00003.bin", + "pytorch_model-00002-of-00003.bin", + "pytorch_model-00003-of-00003.bin" + ] + output_dir: ${output_dir}/policy + model_type: MISTRAL + +# checkpointer for the value model - update `checkpoint_files` if resuming from checkpoint +# since this model will be identical to the reward model it's helpful to initialise this +# from the trained reward model weights +value_checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/RM-Mistral-7B/ + checkpoint_files: [ + "model-00001-of-00003.safetensors", + "model-00002-of-00003.safetensors", + "model-00003-of-00003.safetensors" + ] + output_dir: ${output_dir}/value + model_type: REWARD + +# checkpointer for the reward model, ensure `checkpoint_files` +# always points to the original reward model weights, even if resuming training +reward_checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/RM-Mistral-7B/ + checkpoint_files: [ + "model-00001-of-00003.safetensors", + "model-00002-of-00003.safetensors", + "model-00003-of-00003.safetensors" + ] + output_dir: ${output_dir}/value + model_type: REWARD + + +resume_from_checkpoint: False +output_dir: /tmp/mistral7b-ppo-finetune +seed: null +shuffle: True + +# Training env +device: cuda + +# Training arguments +batch_size: 64 +num_steps: 10000 +ppo_epochs: 2 +ppo_batch_size: 32 +gradient_accumulation_steps: 1 + +# Memory management and performance +compile: True +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 3e-6 +optimizer_in_bwd: True +log_peak_memory_stats: False 
+enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + + +# batch size for forward pass during generation +forward_batch_size: 16 +max_generated_tokens: 58 +temperature: 0.7 +top_k: null + +# parameter for penalising generations shorter than `min_response_length` +min_response_length: 18 +# parameter for penalising generations without a stop token +penalise_no_eos: True +# scalar penalty to apply when penalising +reward_penalty: -3 + +# tokens to consider as "end of sequence" tokens +stop_token_ids: [ + 2, # eos_id + 28723 # mistral "." token +] +whiten_rewards: False + +# GAE hyperparameters +gamma: 1 +lmbda: 0.95 + +# PPO hyperparameters +loss: + _component_: torchtune.modules.loss.PPOLoss + epsilon: 0.2 + value_coeff: 0.1 + value_clip_range: 0.2 +kl_coeff: 0.01 + + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} + +log_every_n_steps: 1 diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index d5791fc211..ff277735b6 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -27,6 +27,7 @@ from torchtune import config, modules, utils from torchtune.data import CROSS_ENTROPY_IGNORE_IDX from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf from torchtune.modules.peft.peft_utils import ( disable_adapter, get_adapter_params, @@ -449,7 +450,7 @@ def _setup_data( batch_size=batch_size, sampler=sampler, collate_fn=partial( - utils.padded_collate_dpo, + rlhf.padded_collate_dpo, padding_idx=self._tokenizer.pad_id, ignore_idx=CROSS_ENTROPY_IGNORE_IDX, ), diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index dbb64293d0..3af93bad53 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -19,6 +19,7 @@ from torchtune import config, modules, utils from torchtune.data import CROSS_ENTROPY_IGNORE_IDX from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf from torchtune.modules.peft.peft_utils import ( disable_adapter, get_adapter_params, @@ -345,7 +346,7 @@ def _setup_data( sampler=sampler, batch_size=batch_size, collate_fn=partial( - utils.padded_collate_dpo, + rlhf.padded_collate_dpo, padding_idx=self._tokenizer.pad_id, ignore_idx=CROSS_ENTROPY_IGNORE_IDX, ), diff --git a/recipes/ppo_full_finetune_single_device.py b/recipes/ppo_full_finetune_single_device.py new file mode 100644 index 0000000000..1f4e68b41a --- /dev/null +++ b/recipes/ppo_full_finetune_single_device.py @@ -0,0 +1,1078 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
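# This recipe trains four models in concert: the policy being optimised, a frozen copy of the
# starting policy used as the reference for the KL penalty, a value model (trained jointly with
# the policy and initialised from the reward model weights), and a frozen reward model that
# scores generated responses.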
+ +import math +import os +import sys +from functools import partial +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf +from torchtune.modules.rlhf import PPOStats, Trajectory +from torchtune.recipe_interfaces import FTRecipeInterface +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + Full finetuning recipe for RLHF with PPO for dense transformer-based LLMs such as LLama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + This implementation is based on `Learning to summarize from human feedback None: + + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._dtype == torch.float16: + raise RuntimeError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + # manually setting up a generator for the recipe + self._rng = torch.Generator(self._device).manual_seed(self.seed) + self._total_steps = 0 + self._steps_run = 0 + self._total_epochs = 0 + self._epochs_run = 0 + self.global_step = 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + # setup checkpointers + ( + self._policy_checkpointer, + ref_policy_checkpointer, + self._value_checkpointer, + reward_checkpointer, + ) = self._setup_checkpointers( + cfg.checkpointer, + cfg.ref_policy_checkpointer, + cfg.value_checkpointer, + cfg.reward_checkpointer, + ) + + # load policy checkpoints + policy_model_checkpoint_dict = self._policy_checkpointer.load_checkpoint() + ref_policy_state_dict = ref_policy_checkpointer.load_checkpoint() + + # load reward and value model checkpoints + value_model_checkpoint_dict = self._value_checkpointer.load_checkpoint() + reward_model_state_dict = reward_checkpointer.load_checkpoint() + + # update recipe state + # ``_setup_model`` handles initialization and loading the state dict. 
This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model_compile = cfg.compile + self._optimizer_in_bwd = cfg.optimizer_in_bwd + ( + self._policy_model, + self._value_model, + self._reward_model, + self._ref_policy_model, + ) = self._setup_model( + cfg_model=cfg.policy_model, + cfg_reward_value_model=cfg.reward_and_value_model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=self._model_compile, + policy_state_dict=policy_model_checkpoint_dict[utils.MODEL_KEY], + ref_policy_state_dict=ref_policy_state_dict[utils.MODEL_KEY], + value_model_state_dict=value_model_checkpoint_dict[utils.MODEL_KEY], + reward_model_state_dict=reward_model_state_dict[utils.MODEL_KEY], + ) + + # setup tokenizer + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=cfg.optimizer_in_bwd, + opt_state_dict=( + policy_model_checkpoint_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss is initialized.") + + # sampler and dataloader depends on the tokenizer and should be set + # setup afterit is initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + self._setup_training_parameters(cfg) + self._setup_training_hyperparameters(cfg) + + if self._resume_from_checkpoint: + self._update_recipe_state(policy_model_checkpoint_dict) + + # one "step" is a single gradient update update over a minibatch of trajectories + self.global_step = ( + self._steps_run + * self._ppo_epochs + * (self.batch_size // self._ppo_batch_size) + ) + + def _setup_training_hyperparameters(self, cfg) -> None: + """ + Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters, + generation hyperparameters, reward masking hyperparameters, and stop token ids. + """ + + self._kl_coeff = cfg.kl_coeff + # GAE hyperparameters + self._gamma = cfg.gamma + self._lmbda = cfg.lmbda + self._whiten_rewards = cfg.whiten_rewards + + # trajectory generation args + self._temperature = cfg.temperature + self._top_k = cfg.top_k + self._max_generated_tokens = cfg.max_generated_tokens + + # reward masking args + self._min_response_length = cfg.min_response_length + self._penalise_no_eos = cfg.penalise_no_eos + self._reward_penalty = cfg.reward_penalty + + # lots of hand holding for stop tokens + if cfg.get("stop_token_ids", False): + stop_token_ids = cfg.stop_token_ids + if self._tokenizer.eos_id not in stop_token_ids: + warn( + f"tokenizer eos_id ({self._tokenizer.eos_id}) is not in stop_token_ids ({stop_token_ids})." + "This may lead to unexpected behaviour." + ) + else: + if not hasattr(self._tokenizer.stop_tokens): + warn( + "No stop tokens defined in tokenizer, and no stop_token_ids provided. This may lead to unexpected behaviour." 
+ ) + stop_token_ids = [] + else: + stop_token_ids = self._tokenizer.stop_tokens + self._stop_token_ids = torch.tensor(stop_token_ids, device=self._device) + + def _setup_training_parameters(self, cfg: DictConfig) -> None: + """ + Validates and sets up parameters for used during training and for tracking training state, + batch sizes for model forward passes during trajectory generation, PPO minibatches, and + PPO microbatches for gradient accumulation. + + Raises + - ValueError if: + - batch_size is not divisible by forward_batch_size + - batch_size is not divisible by ppo_batch_size + - ppo_batch_size is not divisible by gradient_accumulation_steps + - num_steps is less than batch_size + - gradient_accumulation_steps > 1 and optimizer_in_bwd is True + """ + self.batch_size = cfg.batch_size + self._forward_batch_size = cfg.forward_batch_size + self._ppo_epochs = cfg.ppo_epochs + self._ppo_batch_size = cfg.ppo_batch_size + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._ppo_backward_batch_size = ( + cfg.ppo_batch_size // self._gradient_accumulation_steps + ) + + if self.batch_size % self._forward_batch_size != 0: + raise ValueError( + f"batch_size ({self.batch_size}) must be exactly divisible by " + f"forward_batch_size ({self._forward_batch_size})." + ) + if self.batch_size % self._ppo_batch_size != 0: + raise ValueError( + f"batch_size ({self.batch_size}) must be exactly divisible by " + f"ppo_batch_size ({self._ppo_batch_size})." + ) + if self._ppo_batch_size % self._gradient_accumulation_steps != 0: + raise ValueError( + f"ppo_batch_size ({self._ppo_batch_size}) must be exactly divisible " + f"by gradient_accumulation_steps ({self._gradient_accumulation_steps})." + ) + + if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + self._total_steps = cfg.num_steps // self.batch_size + batches_per_epoch = max( + 1, len(self._dataloader) + ) # when we only have a single batch in the dataset + + self._total_epochs = math.ceil(self._total_steps / batches_per_epoch) + if self._total_steps == 0: + raise ValueError( + f"num_steps {cfg.num_steps} must be greater than the batch size {self.batch_size}." + ) + if self._total_steps < len(self._dataloader): + warn( + f"There are fewer total steps ({self._total_steps}, (num_steps//batch_size) " + f"than there are batches ({len(self._dataloader)}) in the dataset. " + f"Training will stop after ({self._total_steps}) steps without saving intermediate checkpoints" + ) + if (self._total_steps > batches_per_epoch) and ( + self._total_steps % batches_per_epoch != 0 + ): + warn( + f"num_steps ({cfg.num_steps}) is not exactly divisible by " + f"the number of batches in the dataset ({batches_per_epoch}). " + f"Intermediate checkpoints will only be saved every {batches_per_epoch} steps." + ) + log.info( + f"Total steps to run: {self._total_steps}, Total epochs to run: {self._total_epochs}" + ) + + def _setup_checkpointers( + self, + policy_cfg: DictConfig, + ref_policy_cfg: DictConfig, + value_cfg: DictConfig, + reward_cfg: DictConfig, + ) -> Tuple[ + utils.Checkpointer, utils.Checkpointer, utils.Checkpointer, utils.Checkpointer + ]: + """ + Sets up checkpointers for policy, reference policy, value, and reward models. + Only the policy checkpoint handles recipe state for resuming from checkpoints. 
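        Returns a tuple of (policy, reference policy, value, reward) checkpointers.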
+ """ + + if not self._resume_from_checkpoint: + assert policy_cfg.checkpoint_dir == ref_policy_cfg.checkpoint_dir, ( + "Policy and reference policy should be loaded from the same checkpoint directories" + f"at the start of training. Found: {policy_cfg.checkpoint_dir} and" + f"{ref_policy_cfg.checkpoint_dir}" + ) + assert policy_cfg.checkpoint_files == ref_policy_cfg.checkpoint_files, ( + "Policy and reference policy should be loaded from the same checkpoint files" + f"at the start of training. Found: {policy_cfg.checkpoint_files} and" + f"{ref_policy_cfg.checkpoint_files}" + ) + + policy_checkpointer = config.instantiate( + policy_cfg, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + + ref_policy_checkpointer = config.instantiate( + ref_policy_cfg, + resume_from_checkpoint=False, + ) + + value_checkpointer = config.instantiate( + value_cfg, + resume_from_checkpoint=False, + ) + + reward_checkpointer = config.instantiate( + reward_cfg, + resume_from_checkpoint=False, + ) + + return ( + policy_checkpointer, + ref_policy_checkpointer, + value_checkpointer, + reward_checkpointer, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + cfg_reward_value_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + policy_state_dict: Dict[str, Any], + ref_policy_state_dict: Dict[str, Any], + value_model_state_dict: Dict[str, Any], + reward_model_state_dict: Dict[str, Any], + ) -> Tuple[nn.Module, nn.Module, nn.Module]: + """ + Sets up the policy model, reference policy model, reward model, and value model. + """ + + with utils.set_default_dtype(self._dtype), self._device: + policy_model = config.instantiate(cfg_model) + ref_policy_model = config.instantiate(cfg_model) + reward_model = config.instantiate(cfg_reward_value_model) + value_model = config.instantiate(cfg_reward_value_model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + policy_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + utils.set_activation_checkpointing( + value_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + policy_model.load_state_dict(policy_state_dict) + ref_policy_model.load_state_dict(ref_policy_state_dict) + + reward_missing, reward_unexpected = reward_model.load_state_dict( + reward_model_state_dict, strict=False + ) + value_missing, value_unexpected = value_model.load_state_dict( + value_model_state_dict, strict=False + ) + + # some extra validation for HF classifier checkpoints with a `score.bias` present + assert ( + reward_missing == value_missing == [] + ), f"Missing keys in reward ({reward_missing}) and value model ({value_missing}) state dicts." + + if reward_unexpected or value_unexpected: + # the only unexpected keys should be when pre-trained HF models were saved with + # bias=True in final classification layers. This happens when training a reward model with TRL. + assert ( + reward_unexpected == value_unexpected == ["output.bias"] + ), f"Unexpected keys in reward ({reward_unexpected}) and value model ({value_unexpected}) state dicts." + + # Validate models were loaded in with the expected dtype. 
+ utils.validate_expected_param_dtype( + value_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + reward_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + value_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + ref_policy_model.named_parameters(), dtype=self._dtype + ) + + log.info(f"Models are initialized with precision {self._dtype}.") + + # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs + # between ref policy and current policy + for module in policy_model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + for module in value_model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + + # disabling grad and dropout in reward and reference policy models + reward_model.eval() + ref_policy_model.eval() + + for p in reward_model.parameters(): + p.requires_grad = False + + for p in ref_policy_model.parameters(): + p.requires_grad = False + + # Compile model, if enabled. + if compile_model: + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + log.info("Compiling models torch.compile...") + + policy_model.compile(backend=backend) + reward_model.compile(backend=backend) + ref_policy_model.compile(backend=backend) + value_model.compile(backend=backend) + + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + return policy_model, value_model, reward_model, ref_policy_model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optimizer: + + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + p: config.instantiate(cfg_optimizer, [p]) + for p in chain( + self._policy_model.parameters(), self._value_model.parameters() + ) + } + # Register optimizer step hooks on the models to run optimizer in backward. + utils.register_optim_in_bwd_hooks( + model=self._policy_model, optim_dict=optim_dict + ) + utils.register_optim_in_bwd_hooks( + model=self._value_model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._policy_model, optim_dict=optim_dict + ) + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._value_model, optim_dict=optim_dict + ) + # Load optimizer states. If optimizer states are being restored in an optimizer in backward + # run, these need to have been saved with the same setting. Cannot restore from runs that did not + # use optimizer in backward. + if opt_state_dict is not None: + try: + self._optim_ckpt_wrapper.load_state_dict(opt_state_dict) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." 
+ ) from e + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate( + cfg_optimizer, + chain(self._policy_model.parameters(), self._value_model.parameters()), + ) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, cfg_dataset: DictConfig, shuffle: bool, batch_size: int + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. + """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=partial( + rlhf.left_padded_collate, + padding_idx=self._tokenizer.pad_id, + ), + drop_last=True, + ) + + return sampler, dataloader + + def save_checkpoint( + self, epoch: int, is_intermediate_checkpoint: bool = False + ) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + policy_ckpt_dict = {utils.MODEL_KEY: self._policy_model.state_dict()} + value_ckpt_dict = {utils.MODEL_KEY: self._value_model.state_dict()} + + # if training is in-progress, checkpoint the optimizer state and rng state as well + if is_intermediate_checkpoint: + policy_ckpt_dict.update( + { + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self._epochs_run, + utils.TOTAL_EPOCHS_KEY: self._total_epochs, + utils.MAX_STEPS_KEY: self._total_steps, + utils.STEPS_KEY: self._steps_run, + utils.RNG_KEY: self._rng.get_state(), + } + ) + if not self._optimizer_in_bwd: + policy_ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() + else: + policy_ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + + self._policy_checkpointer.save_checkpoint( + policy_ckpt_dict, + epoch=epoch, + intermediate_checkpoint=is_intermediate_checkpoint, + ) + + self._value_checkpointer.save_checkpoint( + value_ckpt_dict, + epoch=epoch, + intermediate_checkpoint=False, + ) + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + # If seed or total_steps, or total_epochs don't match, + # warn the user and overwrite. + try: + if ( + self.seed != ckpt_dict[utils.SEED_KEY] + or self._total_steps != ckpt_dict[utils.MAX_STEPS_KEY] + or self._total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + ): + warn( + message="""Configured value for seed, total_steps, or total_epochs + does not match the value stored in checkpoint.""" + ) + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self._rng.set_state(ckpt_dict[utils.RNG_KEY]) + self._steps_run = ckpt_dict[utils.STEPS_KEY] + self._total_steps = ckpt_dict[utils.MAX_STEPS_KEY] + self._total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self._epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + except KeyError as e: + raise KeyError from e( + "Checkpoint does not contain the required keys needed for updating recipe state." + "Are you sure you passed in the right recipe checkpoint?" 
+ ) + + def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: + """ + Generates a trajectory given the current policy and value models, the reference policy model, the reward model, + and batch of inputs. This is done over the following steps: + + 1: Generate responses, and logits corresponding to the responses using the current policy, + generating (query, response) pairs. + 2. Estimate logprobs of the generated responses using the current policy. + 3. Estimate values from the generated responses using the current value function. + 4. Replace any tokens in the response after the first stop token (usually EOS token) with padding, + producting truncated responses. + 5. Run the reward model on the (query, truncated-response) pairs. + 6. Mask out all the invalid values in the trajectory due to padding tokens. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + + Returns: + Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory` comprising + the current trajectory. + """ + batch_size, context_length = input_ids.shape + + # step 1: generate responses, and logits corresponding to the responses using the current policy + query_responses, logits = rlhf.generate_with_logits( + model=self._policy_model, + prompt=input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + ) + + responses = query_responses[:, context_length:].clone() + query_response_padding_masks = query_responses == self._tokenizer.pad_id + + # step 1.1 create attention masks and position IDs for any padding tokens in inputs, used for future forward passes + masks = rlhf.get_causal_mask(~(query_response_padding_masks)) + position_ids = (~query_response_padding_masks).cumsum(-1) - ( + ~query_response_padding_masks + ).long() + position_ids = position_ids.type(torch.int) + + del query_response_padding_masks + + # step 2. estimate logprobs of the responses using the current policy + logits = logits[:, context_length - 1 :] + logprobs = rlhf.logits_to_logprobs(logits, responses, self._temperature) + + del logits + + # step 2.1 estimate logprobs of the responses using the reference policy + ref_logits = self._ref_policy_model( + query_responses, input_pos=position_ids, mask=masks + ) + ref_logits = rlhf.truncate_sequence_for_logprobs(ref_logits, context_length) + ref_logprobs = rlhf.logits_to_logprobs(ref_logits, responses, self._temperature) + + del ref_logits + + # step 3. estimate values from the responses using the value function + values = self._value_model(query_responses, input_pos=position_ids, mask=masks) + values = rlhf.truncate_sequence_for_logprobs(values, context_length).squeeze(-1) + + # step 4. replace any tokens in the responses after the first stop token (usually EOS token) with padding + # resulting in truncated responses + response_padding_masks, responses = rlhf.truncate_sequence_at_first_stop_token( + responses, self._stop_token_ids, self._tokenizer.pad_id + ) + + # step 5. 
run the reward model on the (query, truncated-response) pairs + scores = self._reward_model( + torch.cat([input_ids, responses], dim=1), + input_pos=position_ids, + mask=masks, + ) + + del responses + + # step 5.1 the scores from the reward model are the logits for the last non-padding token in + # each (query, truncated-response) pair + seq_lens = utils.get_unmasked_sequence_lengths(response_padding_masks) + scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) + + # step 5.2 if configured, apply any penalties for sequences without EOS tokens + # or shorter than a certain length + if self._penalise_no_eos or self._min_response_length: + reward_penalty_mask = rlhf.get_reward_penalty_mask( + response_padding_masks, + seq_lens, + self._penalise_no_eos, + self._min_response_length, + ) + scores[reward_penalty_mask] = self._reward_penalty + + # step 6. mask out all the invalid values in the trajectory due to padding tokens + logprobs[response_padding_masks] = 1.0 + ref_logprobs[response_padding_masks] = 1.0 + + # step 6.1 values are masked out *after* the last valid token in the response + value_seq_idxs = torch.where( + (seq_lens > 0) & (seq_lens < self._max_generated_tokens - 1), + seq_lens + 1, + seq_lens, + ) + value_padding_masks = response_padding_masks.clone() + value_padding_masks[ + torch.arange(batch_size, device=value_padding_masks.device), + value_seq_idxs, + ] = False + + values[value_padding_masks] = 0.0 + + return Trajectory( + query_responses=query_responses, + logprobs=logprobs, + ref_logprobs=ref_logprobs, + values=values, + masks=masks, + position_ids=position_ids, + response_padding_masks=response_padding_masks, + value_padding_masks=value_padding_masks, + value_seq_idxs=value_seq_idxs, + scores=scores, + seq_lens=seq_lens, + ) + + def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: + """ + Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. + See ``generate_trajectory`` for more details. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + + Returns: + Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory`, comprising + the current trajectory. + """ + trajectories: List[Trajectory] = [] + with torch.no_grad(): + for batch_start in range(0, self.batch_size, self._forward_batch_size): + batch_input_ids = input_ids[ + batch_start : batch_start + self._forward_batch_size + ] + trajectories.append(self.generate_trajectory(batch_input_ids)) + return Trajectory(*map(torch.cat, zip(*trajectories))) + + def train(self) -> None: + """ + The core training loop.""" + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward." + "Expect a relatively slow first iteration." + ) + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + + training_completed = False + pbar = tqdm(total=self._total_steps, initial=self._steps_run) + for curr_epoch in range(self._epochs_run, self._total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + for _, batch in enumerate(self._dataloader): + batch = batch.to(self._device) + _, context_length = batch.shape + + # step 1. 
generate the trajectory using: + # - the current policy (pi_theta) + # - the current value function (V_phi) + # - the reference frozen policy model (pi_theta_0) + trajectory = self.generate_trajectory_batched(batch) + + # step 2. get the rewards for the current trajectory. these are based on: + # - the divergence between the current policy and the reference policy + # - the scores from the reward model + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + trajectory.scores, + trajectory.logprobs, + trajectory.ref_logprobs, + self._kl_coeff, + trajectory.value_seq_idxs, + ) + + # step 3. estimate the advantages using Generalized Advantage Estimation (GAE) + advantages, returns = rlhf.estimate_advantages( + trajectory.values, + rewards, + self._gamma, + self._lmbda, + masks=~trajectory.response_padding_masks, + ) + + # step 4. optimise using the PPO objective over multiple epochs + ppo_stats: List[PPOStats] = [] + for _ in range(self._ppo_epochs): + batch_idxs = torch.randperm(self.batch_size, device=self._device) + for i in range(0, self.batch_size, self._ppo_batch_size): + mini_batch_idxs = batch_idxs[i : i + self._ppo_batch_size] + + batch_ppo_stats: List[PPOStats] = [] + for j in range( + 0, self._ppo_batch_size, self._ppo_backward_batch_size + ): + backward_batch_idxs = mini_batch_idxs[ + j : j + self._ppo_backward_batch_size + ] + + batch_trajectory = Trajectory( + *map( + partial( + torch.index_select, + dim=0, + index=backward_batch_idxs, + ), + trajectory, + ) + ) + batch_ppo_stats.append( + self._ppo_step( + batch_trajectory, + advantages[backward_batch_idxs], + returns[backward_batch_idxs], + context_length, + ) + ) + del batch_trajectory + + ppo_stats.append(PPOStats(*map(sum, zip(*batch_ppo_stats)))) + + if not self._optimizer_in_bwd: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + self.global_step += 1 + + # step 5. profit + self._steps_run += 1 + if self._steps_run % self._log_every_n_steps == 0: + self.log_metrics( + trajectory, + PPOStats(*map(torch.stack, zip(*ppo_stats))), + kl, + kl_rewards, + ) + self.cleanup_after_step(trajectory, advantages, returns, kl, kl_rewards) + pbar.update(1) + if self._steps_run == self._total_steps: + training_completed = True + break + + # save checkpoint at current epoch + self._epochs_run += 1 + + self.save_checkpoint( + curr_epoch, is_intermediate_checkpoint=not training_completed + ) + if training_completed: + return + + def _ppo_step( + self, + trajectory: Trajectory, + advantages: torch.Tensor, + returns: torch.Tensor, + context_length: int, + ) -> PPOStats: + """ + Perform a single PPO optimisation step over a batch of trajectories and corresponding advantages and returns. + + Args: + trajectory (Trajectory): a batch of trajectories + advantages (torch.Tensor): advantages corresponding to the trajectories + returns (torch.Tensor): returns corresponding the trajectories + context_length (int): input ids sequence length + + Returns: + PPOStats: An instance of :class:`~torchtune.modules.rlhf.PPOStats`, a dataclass containing: + - loss (torch.Tensor): The total PPO loss. + - policy_loss (torch.Tensor): The policy function loss. + - value_loss (torch.Tensor): The value function loss. + - ratios (torch.Tensor): The ratio between the current and old policy probabilities. + - clipfrac (torch.Tensor): The fraction of ratios that were clipped. + - approx_policy_kls: Average estimated KL divergence between the policy before and after the optimisation step. 
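
        Example:
            A minimal sketch of the clipped policy objective optimised in this step, assuming
            simple 1D tensors (simplified; see :class:`~torchtune.modules.loss.PPOLoss` for the
            full loss, which also includes a clipped value-function term scaled by ``value_coeff``)::

                import torch

                epsilon = 0.2
                pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2])
                pi_logprobs = torch.tensor([1.5, 1.8, 2.2])
                advantages = torch.tensor([1.0, 1.0, 1.0])

                # importance ratios between the current and old policy
                ratios = torch.exp(pi_logprobs - pi_old_logprobs)
                clipped_ratios = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon)

                # clipped surrogate objective; the policy loss is its negation
                policy_loss = -torch.minimum(
                    advantages * ratios, advantages * clipped_ratios
                ).mean()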
+ + """ + # estimate logprobs from the policy at the current optimisation step + pi_logits = self._policy_model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + pi_logits = rlhf.truncate_sequence_for_logprobs(pi_logits, context_length) + pi_logprobs = rlhf.logits_to_logprobs( + pi_logits, trajectory.query_responses[:, context_length:], self._temperature + ) + pi_logprobs[trajectory.response_padding_masks] = 1.0 + + del pi_logits + + # estimate the values from the value function at the current optimisation step + phi_values = self._value_model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + + phi_values = rlhf.truncate_sequence_for_logprobs( + phi_values, context_length + ).squeeze(-1) + phi_values[trajectory.value_padding_masks] = 0.0 + + # calculate ppo loss + loss, policy_loss, value_loss, ratios, clipfrac = self._loss_fn( + trajectory.logprobs, + pi_logprobs, + advantages, + trajectory.values, + phi_values, + returns, + padding_masks=~trajectory.response_padding_masks, + value_padding_masks=~trajectory.value_padding_masks, + ) + + loss /= self._gradient_accumulation_steps + loss.backward() + + with torch.no_grad(): + approx_policy_kls = ( + 0.5 * (pi_logprobs - trajectory.logprobs).pow(2) + ).mean() + + return PPOStats( + loss, + policy_loss / self._gradient_accumulation_steps, + value_loss / self._gradient_accumulation_steps, + ratios / self._gradient_accumulation_steps, + clipfrac / self._gradient_accumulation_steps, + approx_policy_kls / self._gradient_accumulation_steps, + ) + + def log_metrics( + self, + trajectory: Trajectory, + ppo_stats: PPOStats, + kl: torch.Tensor, + kl_rewards: torch.Tensor, + ) -> None: + """ + Log metrics and statistics for the current step to the metric logger. + """ + log_dict = { + "scores": trajectory.scores.mean(), + "num_stop_tokens": trajectory.response_padding_masks.any(-1).sum(), + "rlhf_reward": trajectory.scores.mean() + kl_rewards.sum(1).mean(), + "kl": kl.sum(1).mean(), + "kl_reward": kl_rewards.sum(1).mean(), + "loss": ppo_stats.loss.mean(), + "policy_loss": ppo_stats.policy_loss.mean(), + "value_loss": ppo_stats.value_loss.mean(), + "clipfrac": ppo_stats.clipfrac.mean(), + "ratios": ppo_stats.ratios.mean(), + "approx_policy_kl": ppo_stats.approx_policy_kls.mean(), + "response_lengths": trajectory.seq_lens.float().mean(), + } + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + + self._metric_logger.log_dict(log_dict, step=self.global_step) + + def cleanup_after_step( + self, + trajectory: Trajectory, + advantages: torch.Tensor, + returns: torch.Tensor, + kl: torch.Tensor, + kl_rewards: torch.Tensor, + ) -> None: + """ + Cleanup tensors after each PPO step to free up memory. + """ + # there shouldn't be any floating references to the individual tensors at the this point, so gc can do its thing + for v in trajectory: + del v + del trajectory + del advantages + del returns + del kl + del kl_rewards + + def cleanup(self, **kwargs) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. 
+ + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="PPOFullFinetuneRecipeSingleDevice", cfg=cfg) + recipe = PPOFullFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/tests/cache_artifacts.sh b/tests/cache_artifacts.sh index 19bf75f07e..81b50b5889 100755 --- a/tests/cache_artifacts.sh +++ b/tests/cache_artifacts.sh @@ -17,6 +17,7 @@ SMALL_MODEL_URLS=( "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-meta-03082024.pt" "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-03082024.pt" "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-tune-llama3-05052024.pt" + "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-reward-07122024.pt" ) FULL_MODEL_URL=("s3://pytorch-multimodal/llama2-7b-torchtune.pt") TOKENIZER_URLS=( diff --git a/tests/recipes/test_ppo_full_tunetune_single_device.py b/tests/recipes/test_ppo_full_tunetune_single_device.py new file mode 100644 index 0000000000..7113a12599 --- /dev/null +++ b/tests/recipes/test_ppo_full_tunetune_single_device.py @@ -0,0 +1,373 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from tests.common import TUNE_PATH + +from tests.recipes.utils import ( + dummy_text_completion_alpaca_dataset_config, + llama2_classifier_test_config, + llama2_test_config, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, +) + + +class TestPPOFullFinetuneSingleDeviceRecipe: + def _get_test_config_overrides(self): + return [ + "batch_size=4", + "forward_batch_size=4", + "ppo_batch_size=4", + "ppo_epochs=1", + "num_steps=16", + "temperature=1.0", + "gradient_accumulation_steps=1", + "device=cpu", + "dtype=fp32", + "enable_activation_checkpointing=False", + "tokenizer.path=/tmp/test-artifacts/tokenizer.model", + "tokenizer._component_=torchtune.models.llama2.llama2_tokenizer", + "seed=9", + "optimizer=torch.optim.AdamW", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + ] + dummy_text_completion_alpaca_dataset_config() + + @pytest.mark.integration_test + def test_loss(self, tmpdir, monkeypatch): + + reward_ckpt = "llama2_reward_hf" + policy_ckpt = "llama2_hf" + reward_ckpt_path = Path(CKPT_MODEL_PATHS[reward_ckpt]) + policy_ckpt_path = Path(CKPT_MODEL_PATHS[policy_ckpt]) + + ckpt_dir = policy_ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + policy_tmpdir = (tmpdir / "policy").mkdir() + value_tmpdir = (tmpdir / "value").mkdir() + + write_hf_ckpt_config(ckpt_dir) + cmd_1 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + 
value_checkpointer.checkpoint_dir='{ckpt_dir}' \ + value_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + """.split() + + model_config = llama2_test_config() + model_config = [k.replace("model.", "policy_model.") for k in model_config] + model_config += ["policy_model.intermediate_dim=null"] + + reward_and_value_model_config = llama2_classifier_test_config() + reward_and_value_model_config = [ + k.replace("model.", "reward_and_value_model.") + for k in reward_and_value_model_config + ] + reward_and_value_model_config += [ + "reward_and_value_model.intermediate_dim=null" + ] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = [ + 1.0403, + 0.9495, + 0.9084, + 1.0494, + 0.9609, + 0.8846, + 1.0282, + 0.9390, + 0.8915, + 1.0166, + 0.9231, + 0.9352, + ] + torch.testing.assert_close( + loss_values, expected_loss_values, atol=1e-4, rtol=1e-5 + ) + + @pytest.mark.integration_test + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state correctly saved and restored after training.""" + + reward_ckpt = "llama2_reward_hf" + policy_ckpt = "llama2_hf" + reward_ckpt_path = Path(CKPT_MODEL_PATHS[reward_ckpt]) + policy_ckpt_path = Path(CKPT_MODEL_PATHS[policy_ckpt]) + + ckpt_dir = policy_ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + policy_tmpdir = (tmpdir / "policy").mkdir() + value_tmpdir = (tmpdir / "value").mkdir() + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(policy_tmpdir) + write_hf_ckpt_config(value_tmpdir) + # There are 4 steps in total (num_steps / batch size) + # and the dataset has 8 samples, so each epoch will be 2 batches + # a single step is a single batch update, and we checkpoint at every epoch (2 steps) + # so we're expecting an intermediate checkpoint at step 2. 
The idea here is to train for 4 steps, + # resume after 2, and ensure the losses for the final two steps after resuming are identical + cmd_1 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{ckpt_dir}' \ + value_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + """.split() + + model_config = llama2_test_config() + model_config = [k.replace("model.", "policy_model.") for k in model_config] + model_config += ["policy_model.intermediate_dim=null"] + + reward_and_value_model_config = llama2_classifier_test_config() + reward_and_value_model_config = [ + k.replace("model.", "reward_and_value_model.") + for k in reward_and_value_model_config + ] + reward_and_value_model_config += [ + "reward_and_value_model.intermediate_dim=null" + ] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + + # Resume training at step 2 + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + cmd_2 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{policy_tmpdir}' \ + checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{value_tmpdir}' \ + value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, "hf_model_0001_0.pt")}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + resume_from_checkpoint=True \ + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={resumed_log_file} \ + """.split() + + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + + # losses at each step are (loss, policy_loss, value_loss) + torch.testing.assert_close( + loss_values[6:], resumed_loss_values, rtol=1e-4, atol=1e-4 + ) + + 
@pytest.mark.integration_test + def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatch): + """Test whether the recipe state correctly saves and restores optimizer state + when using ``optimizer_in_bwd``, since the optimizer checkpoint dict will include + parameters for two models. + + This is identical to ``test_training_state_on_resume``, but adds optimizer_in_bwd. + """ + + reward_ckpt = "llama2_reward_hf" + policy_ckpt = "llama2_hf" + reward_ckpt_path = Path(CKPT_MODEL_PATHS[reward_ckpt]) + policy_ckpt_path = Path(CKPT_MODEL_PATHS[policy_ckpt]) + + ckpt_dir = policy_ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + policy_tmpdir = (tmpdir / "policy").mkdir() + value_tmpdir = (tmpdir / "value").mkdir() + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(policy_tmpdir) + write_hf_ckpt_config(value_tmpdir) + cmd_1 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{ckpt_dir}' \ + value_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + + optimizer_in_bwd=True + """.split() + + model_config = llama2_test_config() + model_config = [k.replace("model.", "policy_model.") for k in model_config] + model_config += ["policy_model.intermediate_dim=null"] + + reward_and_value_model_config = llama2_classifier_test_config() + reward_and_value_model_config = [ + k.replace("model.", "reward_and_value_model.") + for k in reward_and_value_model_config + ] + reward_and_value_model_config += [ + "reward_and_value_model.intermediate_dim=null" + ] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + + # Resume training at step 2 + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + cmd_2 = f""" + tune run ppo_full_finetune_single_device \ + --config mistral/7B_full_ppo_low_memory \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ + checkpointer.checkpoint_dir='{policy_tmpdir}' \ + checkpointer.checkpoint_files=[{os.path.join(policy_tmpdir, "hf_model_0001_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(policy_tmpdir, "recipe_state.pt")}\ + checkpointer.output_dir={policy_tmpdir} \ + checkpointer.model_type=LLAMA2 \ + + ref_policy_checkpointer.checkpoint_dir='{ckpt_dir}' \ + ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ + + value_checkpointer.checkpoint_dir='{value_tmpdir}' \ + value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, 
"hf_model_0001_0.pt")}]\ + value_checkpointer.output_dir={value_tmpdir} \ + + reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ + reward_checkpointer.checkpoint_files=[{reward_ckpt_path}]\ + + resume_from_checkpoint=True \ + metric_logger._component_=torchtune.utils.metric_logging.DiskLogger \ + metric_logger.filename={resumed_log_file} \ + + optimizer_in_bwd=True + """.split() + + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides() + + model_config + + reward_and_value_model_config + ) + + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + + # losses at each step are (loss, policy_loss, value_loss) + torch.testing.assert_close( + loss_values[6:], resumed_loss_values, rtol=1e-4, atol=1e-4 + ) diff --git a/tests/recipes/utils.py b/tests/recipes/utils.py index 66297984fb..a1c820d5b7 100644 --- a/tests/recipes/utils.py +++ b/tests/recipes/utils.py @@ -61,6 +61,24 @@ def dummy_alpaca_dataset_config(): return out +def dummy_text_completion_alpaca_dataset_config(): + """ + Constructs a minimal text-completion-style dataset from ``alpaca_tiny.json``. + This is used for testing PPO fine-tuning. + """ + data_files = os.path.join(get_assets_path(), "alpaca_tiny.json") + out = [ + "dataset._component_=torchtune.datasets.text_completion_dataset", + "dataset.source='json'", + f"dataset.data_files={data_files}", + "dataset.column='instruction'", + "dataset.split='train[:10%]'", # 10% of the dataset gets us 8 batches + "dataset.max_seq_len=64", + "dataset.add_eos=False", + ] + return out + + def llama2_test_config() -> List[str]: return [ "model._component_=torchtune.models.llama2.llama2", @@ -74,6 +92,20 @@ def llama2_test_config() -> List[str]: ] +def llama2_classifier_test_config() -> List[str]: + return [ + "model._component_=torchtune.models.llama2.llama2_classifier", + "model.num_classes=1", + "model.vocab_size=32_000", + "model.num_layers=4", + "model.num_heads=16", + "model.embed_dim=256", + "model.max_seq_len=2048", + "model.norm_eps=1e-5", + "model.num_kv_heads=8", + ] + + def llama3_test_config() -> List[str]: return [ "model._component_=torchtune.models.llama3.llama3", diff --git a/tests/test_utils.py b/tests/test_utils.py index 35ea7cef22..db996533ee 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,6 +29,7 @@ "llama2_tune": "/tmp/test-artifacts/small-ckpt-tune-03082024.pt", "llama2_meta": "/tmp/test-artifacts/small-ckpt-meta-03082024.pt", "llama2_hf": "/tmp/test-artifacts/small-ckpt-hf-03082024.pt", + "llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt", "llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt", "llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt", } diff --git a/tests/torchtune/modules/loss/__init__.py b/tests/torchtune/modules/loss/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/tests/torchtune/modules/loss/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/torchtune/modules/loss/test_ppo_loss.py b/tests/torchtune/modules/loss/test_ppo_loss.py new file mode 100644 index 0000000000..6445da3120 --- /dev/null +++ b/tests/torchtune/modules/loss/test_ppo_loss.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from torchtune.modules.loss import PPOLoss
+
+
+@pytest.fixture(autouse=True)
+def random():
+    torch.manual_seed(16)
+
+
+class TestPPOLoss:
+    @pytest.fixture
+    def loss_fn(self):
+        return PPOLoss(
+            value_clip_range=0.2,
+            value_coeff=0.1,
+            epsilon=0.2,
+        )
+
+    def test_policy_loss_clipped_for_high_logprobs(self, loss_fn):
+        # fixed old policy logprobs, advantages, returns
+        pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2])
+        advantages = torch.tensor([1.0, 1.0, 1.0])
+        values = torch.tensor([1.0, 1.0, 1.0])
+        returns = torch.tensor([1.0, 1.0, 1.0])
+
+        pi_logprobs_high = torch.tensor([1.5, 1.8, 2.2])
+        # ratio will be [e, e, e]
+        # clipped ratio becomes [1.2, 1.2, 1.2] (1+epsilon)
+        # loss becomes max(-e, -1.2) = -1.2 since advantages is 1
+        expected_loss = torch.tensor(-1.2)
+        expected_ratios = torch.exp(torch.ones((3)))
+
+        _, policy_loss, _, ratios, _ = loss_fn(
+            pi_old_logprobs, pi_logprobs_high, advantages, values, values, returns
+        )
+
+        torch.testing.assert_close(
+            policy_loss.mean(), expected_loss, atol=1e-4, rtol=1e-4
+        )
+        torch.testing.assert_close(ratios, expected_ratios.mean(), atol=1e-4, rtol=1e-4)
+
+    def test_policy_loss_clipped_for_low_logprobs(self, loss_fn):
+        # fixed old policy logprobs, advantages, returns
+        pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2])
+        # negative advantages so that clipping at (1 - epsilon) binds for low ratios
+        advantages = torch.tensor([-1.0, -1.0, -1.0])
+        values = torch.tensor([1.0, 1.0, 1.0])
+        returns = torch.tensor([1.0, 1.0, 1.0])
+
+        pi_logprobs_low = torch.tensor([-0.5, -0.2, 0.2])
+        # ratio will be [1/e, 1/e, 1/e] (~0.367)
+        # clipped ratio becomes [0.8, 0.8, 0.8] (1-epsilon)
+        # loss becomes max(1/e, 0.8) = 0.8 since advantages is -1
+        expected_loss = torch.tensor(0.8)
+        expected_ratios = 1 / torch.exp(torch.ones((3)))
+
+        _, policy_loss, _, ratios, _ = loss_fn(
+            pi_old_logprobs, pi_logprobs_low, advantages, values, values, returns
+        )
+
+        torch.testing.assert_close(
+            policy_loss.mean(), expected_loss, atol=1e-4, rtol=1e-4
+        )
+        torch.testing.assert_close(ratios, expected_ratios.mean(), atol=1e-4, rtol=1e-4)
+
+    def test_policy_loss_not_clipped(self, loss_fn):
+        # fixed old policy logprobs, advantages, returns
+        pi_old_logprobs = torch.tensor([0.5, 0.8, 1.2])
+        advantages = torch.tensor([1.0, 1.0, 1.0])
+        values = torch.tensor([1.0, 1.0, 1.0])
+        returns = torch.tensor([1.0, 1.0, 1.0])
+
+        pi_logprobs_unclipped = torch.tensor([0.6, 0.9, 1.3])
+        # ratio will be [e^0.1, e^0.1, e^0.1] (~1.1)
+        # ratio is not clipped since it is within [1-epsilon, 1+epsilon], [0.8, 1.2]
+        # loss is therefore -e^0.1 since advantages is 1
+        expected_loss = -torch.tensor(0.1).exp()
+        expected_ratios = torch.exp(torch.ones(3) * 0.1)
+
+        _, policy_loss, _, ratios, _ = loss_fn(
+            pi_old_logprobs, pi_logprobs_unclipped, advantages, values, values, returns
+        )
+
+        torch.testing.assert_close(
+            policy_loss.mean(), expected_loss, atol=1e-4, rtol=1e-4
+        )
+        torch.testing.assert_close(ratios, expected_ratios.mean(), atol=1e-4, rtol=1e-4)
+
+    def test_policy_loss_lower_for_higher_advantages(self, loss_fn):
+        pi_logprobs = torch.tensor([-0.5, -0.8, -1.2])
+
+        advantages_high = torch.tensor([1.0, 2.0, 3.0])
+        advantages_low = torch.tensor([0.5, 1.0, 1.5])
+        values = torch.tensor([1.0, 1.0, 1.0])
+        returns = torch.tensor([1.0, 1.0, 1.0])
+
+        _, policy_loss_low, *_ = loss_fn(
+            pi_logprobs, pi_logprobs, advantages_high, values, values, returns
+        )
+        _, policy_loss_high, *_ = loss_fn(
+            pi_logprobs,
pi_logprobs, advantages_low, values, values, returns + ) + + assert policy_loss_low.mean() < policy_loss_high.mean() + + def test_value_loss_lower_for_values_similar_to_return(self, loss_fn): + # fix pi_logrobs, pi_old_logprobs, returns, advantages + pi_logprobs = torch.tensor([-0.5, -0.8, -1.2]) + returns = torch.tensor([1.0, 1.0, 1.0]) + advantages = torch.tensor([1.0, 1.0, 1.0]) + + # values estimates are similar to returns + values_similar = torch.tensor([0.9, 1.0, 1.1]) + # value estimates are less similar to returns + values_less_similar = torch.tensor([0.5, 1.5, 2.0]) + + _, _, value_loss_lower, *_ = loss_fn( + pi_logprobs, + pi_logprobs, + advantages, + values_similar, + values_similar, + returns, + ) + _, _, value_loss_higher, *_ = loss_fn( + pi_logprobs, + pi_logprobs, + advantages, + values_similar, + values_less_similar, + returns, + ) + assert value_loss_lower.mean() < value_loss_higher.mean() diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 5408561f12..c3b6320c66 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -4,114 +4,120 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -import bitsandbytes as bnb -import pytest -import torch -from torchao.dtypes.nf4tensor import NF4Tensor -from torchtune.modules.low_precision import FrozenNF4Linear -from torchtune.utils.seed import set_seed - - -@pytest.fixture(autouse=True) -def random(): - set_seed(31) - - -def _build_bnb_linear(input_weight): - """ - Builds a bnb.nn.LinearNF4 from a given input weight - """ - param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4") - bnb_linear = bnb.nn.LinearNF4( - input_weight.size(0), input_weight.size(1), bias=False - ) - bnb_linear.weight = param - bnb_linear.cuda() - return bnb_linear - - -class TestNF4Linear: - """ - Class for testing our NF4Linear implementation. - """ - - def test_bias_unsupported(self): - with pytest.raises(RuntimeError, match="does not currently support biases"): - _ = FrozenNF4Linear(1, 1, bias=True) - - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_parameters(self, dtype): - nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) - params = list(nf4_linear.parameters()) - assert len(params) == 1 - assert isinstance(params[0], NF4Tensor) - - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_state_dict(self, dtype): - nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) - state_dict = nf4_linear.state_dict() - assert len(state_dict) == 1 - assert isinstance(state_dict["weight"], NF4Tensor) - - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_output_dtype(self, dtype): - # Test to ensure W4 A16 produces A16 / W4A32 produces A32 - nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) - inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) - out = nf4_linear(inp) - assert out.dtype == dtype - - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_backward_dtype(self, dtype): - # Test to ensure backward pass gives activation a bf16 gradient and no gradient - # to the linear's weight, as it is frozen. 
- nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) - inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) - nf4_linear(inp).sum().backward() - assert inp.grad is not None and inp.grad.dtype == dtype - assert nf4_linear.weight.grad is None - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_nf4_reconstruction_vs_bnb(self, dtype): - """ - Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when - reconstructing the respective original weights. - """ - dim = 512 - nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) - orig_weight = nf4_linear.weight.get_original_weight().clone().detach() - bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) - - # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65 - bnb_reconstruction = bnb_nf4_linear( - torch.eye(dim, dim, dtype=dtype, device="cuda") - ) - # Ensure nf4_linear and bnb reconstructions are close to each other. - assert torch.allclose( - bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2 - ) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) - def test_nf4_bnb_linear(self, dtype): - """ - This test ensures that nf4_linear is "no worse" than BNB by ensuring the - error compared to a bf16 linear is not more than BNB's implementation. - """ - dim = 512 - nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) - orig_weight = nf4_linear.weight.get_original_weight().clone().detach() - bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) - bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) - - inp = torch.randn(2, 512, dtype=dtype, device="cuda") - - out_nf4 = nf4_linear(inp) - out_bnb = bnb_nf4_linear(inp) - out_ref = bf16_linear(inp) - - err_bnb = out_bnb - out_ref - err_native = out_nf4 - out_ref - assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2) +# # Copyright (c) Meta Platforms, Inc. and affiliates. +# # All rights reserved. +# # +# # This source code is licensed under the BSD-style license found in the +# # LICENSE file in the root directory of this source tree. + + +# import bitsandbytes as bnb +# import pytest +# import torch +# from torchao.dtypes.nf4tensor import NF4Tensor +# from torchtune.modules.low_precision import FrozenNF4Linear +# from torchtune.utils.seed import set_seed + + +# @pytest.fixture(autouse=True) +# def random(): +# set_seed(31) + + +# def _build_bnb_linear(input_weight): +# """ +# Builds a bnb.nn.LinearNF4 from a given input weight +# """ +# param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4") +# bnb_linear = bnb.nn.LinearNF4( +# input_weight.size(0), input_weight.size(1), bias=False +# ) +# bnb_linear.weight = param +# bnb_linear.cuda() +# return bnb_linear + + +# class TestNF4Linear: +# """ +# Class for testing our NF4Linear implementation. 
+# """ + +# def test_bias_unsupported(self): +# with pytest.raises(RuntimeError, match="does not currently support biases"): +# _ = FrozenNF4Linear(1, 1, bias=True) + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_parameters(self, dtype): +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# params = list(nf4_linear.parameters()) +# assert len(params) == 1 +# assert isinstance(params[0], NF4Tensor) + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_state_dict(self, dtype): +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# state_dict = nf4_linear.state_dict() +# assert len(state_dict) == 1 +# assert isinstance(state_dict["weight"], NF4Tensor) + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_output_dtype(self, dtype): +# # Test to ensure W4 A16 produces A16 / W4A32 produces A32 +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) +# out = nf4_linear(inp) +# assert out.dtype == dtype + +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_backward_dtype(self, dtype): +# # Test to ensure backward pass gives activation a bf16 gradient and no gradient +# # to the linear's weight, as it is frozen. +# nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) +# inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) +# nf4_linear(inp).sum().backward() +# assert inp.grad is not None and inp.grad.dtype == dtype +# assert nf4_linear.weight.grad is None + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_nf4_reconstruction_vs_bnb(self, dtype): +# """ +# Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when +# reconstructing the respective original weights. +# """ +# dim = 512 +# nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) +# orig_weight = nf4_linear.weight.get_original_weight().clone().detach() +# bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + +# # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65 +# bnb_reconstruction = bnb_nf4_linear( +# torch.eye(dim, dim, dtype=dtype, device="cuda") +# ) +# # Ensure nf4_linear and bnb reconstructions are close to each other. +# assert torch.allclose( +# bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2 +# ) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") +# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +# def test_nf4_bnb_linear(self, dtype): +# """ +# This test ensures that nf4_linear is "no worse" than BNB by ensuring the +# error compared to a bf16 linear is not more than BNB's implementation. 
+# """ +# dim = 512 +# nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) +# orig_weight = nf4_linear.weight.get_original_weight().clone().detach() +# bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) +# bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) + +# inp = torch.randn(2, 512, dtype=dtype, device="cuda") + +# out_nf4 = nf4_linear(inp) +# out_bnb = bnb_nf4_linear(inp) +# out_ref = bf16_linear(inp) + +# err_bnb = out_bnb - out_ref +# err_native = out_nf4 - out_ref +# assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2) diff --git a/tests/torchtune/modules/rlhf/__init__.py b/tests/torchtune/modules/rlhf/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/tests/torchtune/modules/rlhf/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/torchtune/modules/rlhf/test_collate.py b/tests/torchtune/modules/rlhf/test_collate.py new file mode 100644 index 0000000000..1fecff7180 --- /dev/null +++ b/tests/torchtune/modules/rlhf/test_collate.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import torch + +from torchtune.modules.rlhf import left_padded_collate + + +class TestLeftPaddedCollate: + def test_left_padded_collate(self): + """ + Tests that input sequences are left-padded to the max seq len. + """ + padding_idx = -8 + tokens = [ + { + "tokens": [ + 1, + 2, + ], + }, + { + "tokens": [3], + }, + { + "tokens": [4, 5, 6, 7], + }, + ] + padded_tokens = left_padded_collate(batch=tokens, padding_idx=padding_idx) + + expected_padded_tokens = torch.tensor( + [ + [padding_idx, padding_idx, 1, 2], + [padding_idx, padding_idx, padding_idx, 3], + [4, 5, 6, 7], + ] + ) + torch.testing.assert_close(padded_tokens, expected_padded_tokens) diff --git a/tests/torchtune/modules/rlhf/test_generation.py b/tests/torchtune/modules/rlhf/test_generation.py new file mode 100644 index 0000000000..2613a4f6b8 --- /dev/null +++ b/tests/torchtune/modules/rlhf/test_generation.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest + +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.llama2 import llama2 +from torchtune.modules import rlhf +from torchtune.utils._generation import sample + + +class TestGenerateNextTokenWithLogits: + @pytest.fixture + def generation_model(self): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.eval() + return model + + def test_generate_next_token_with_logits(self, generation_model): + + inputs = torch.tensor( + [ + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + ] + ) + + input_pos = torch.tensor( + [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + ] + ) + + torch.manual_seed(42) + logits, generation = rlhf.generate_next_token_with_logits( + generation_model, input_pos, inputs + ) + + torch.manual_seed(42) + expected_logits = generation_model(inputs, input_pos=input_pos) + expected_generation = sample(logits[:, -1], temperature=1.0, top_k=None) + + torch.testing.assert_close(logits, expected_logits, atol=1e-4, rtol=1e-5) + torch.testing.assert_close(generation, expected_generation, atol=0, rtol=0) + + +class TestGenerate: + """ + Test class for text generation functionality in :func:`~torchtune.modules.rlhf.generate`. + See `torchtune.tests.utils.test_generation` for context. + """ + + @pytest.fixture + def generation_model(self): + model = llama2( + vocab_size=4_000, + embed_dim=128, + num_layers=2, + num_heads=4, + num_kv_heads=4, + max_seq_len=2048, + ) + fixed_init_model(model) + model.eval() + return model + + @pytest.fixture + def prompt_tokens(self): + """ + Pytest fixture to create a list of prompt tokens for testing. + """ + return torch.arange(2, 10) + + @pytest.fixture + def prompt_tokens_batched(self): + """ + Pytest fixture to create a list of batched prompt tokens for testing. + """ + return torch.arange(2, 10).repeat(3, 1) + + @pytest.fixture + def prompt_tokens_padded(self): + """ + Pytest fixture to create a list of left-padded prompt tokens for testing. + """ + return torch.cat([torch.tensor([0, 0]), torch.arange(2, 10)]) + + @pytest.fixture + def prompt_tokens_batched_left_padded(self): + """ + Pytest fixture to create a list of left-padded batched prompt tokens for testing. + """ + return torch.cat([torch.tensor([0, 0]), torch.arange(2, 10)]).repeat(3, 1) + + def test_reproducability_with_and_without_padding_batched( + self, + generation_model, + prompt_tokens_batched_left_padded, + prompt_tokens_batched, + ): + """ + Test to check if the `generate` function produces the same output for inputs that are left padded + and for the same inputs that are not left padded, for a batch of inputs with varying sequence lengths. 
+ """ + temperature = 0.6 + top_k = 100 + + torch.manual_seed(42) + outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens_batched_left_padded, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.manual_seed(42) + expected_outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens_batched, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.testing.assert_close(outputs[:, 2:], expected_outputs, atol=0, rtol=0) + + def test_reproducability_with_and_without_padding( + self, generation_model, prompt_tokens, prompt_tokens_padded + ): + """ + Test to check if the `generate` function produces the same output for inputs that are left padded + and for the same inputs that are not left padded. + """ + temperature = 0.6 + top_k = 100 + + torch.manual_seed(42) + + outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens_padded, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.manual_seed(42) + expected_outputs, _ = rlhf.generate_with_logits( + model=generation_model, + prompt=prompt_tokens, + max_generated_tokens=10, + temperature=temperature, + top_k=top_k, + ) + + torch.testing.assert_close(outputs[:, 2:], expected_outputs, atol=0, rtol=0) + + +class TestGetCausalMask: + @pytest.fixture + def left_padded_prompt_tokens(self): + """ + Pytest fixture to create a list of left-padded prompt tokens for testing. + """ + return torch.cat([torch.tensor([0, 0]), torch.arange(2, 6)]).unsqueeze(0) + + @pytest.fixture + def left_padded_prompt_tokens_batched(self): + """ + Pytest fixture to create a list of left-padded batched prompt tokens for testing. + """ + return torch.tensor( + [[0, 0, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 1]] + ) + + @pytest.fixture + def right_padded_prompt_tokens(self): + """ + Pytest fixture to create a list of right-padded prompt tokens for testing. + """ + return torch.cat([torch.arange(2, 6), torch.tensor([0, 0])]).unsqueeze(0) + + @pytest.fixture + def right_padded_prompt_tokens_batched(self): + """ + Pytest fixture to create a list of right-padded batched prompt tokens for testing. + """ + return torch.tensor( + [[1, 2, 3, 4, 5, 0], [1, 2, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6]] + ) + + @pytest.fixture + def mixed_padded_prompt_tokens(self): + """ + Pytest fixture to create a list of mixed padded prompt tokens for testing. + """ + return torch.cat( + [torch.tensor([0, 0]), torch.arange(2, 6), torch.tensor([0, 0])] + ).unsqueeze(0) + + @pytest.fixture + def mixed_padded_prompt_tokens_batched(self): + """ + Pytest fixture to create a list of mixed padded batched prompt tokens for testing. + """ + return torch.tensor( + [[0, 0, 1, 2, 0, 0], [0, 1, 2, 3, 4, 0], [0, 0, 0, 1, 0, 0]] + ) + + def test_get_causal_mask_for_left_padded_inputs(self, left_padded_prompt_tokens): + """ + Test to check if the `get_causal_mask` function produces the right output for left-padded prompts. 
+ """ + expected_casual_mask = torch.tensor( + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, True, True, False, False], + [False, False, True, True, True, False], + [False, False, True, True, True, True], + ] + ).unsqueeze(0) + + causal_mask = rlhf.get_causal_mask(left_padded_prompt_tokens != 0) + torch.testing.assert_close(causal_mask, expected_casual_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_left_padded_inputs_batched( + self, left_padded_prompt_tokens_batched + ): + """ + Test to check if the `get_causal_mask` function produces the right output for left-padded batched prompts. + """ + expected_causal_mask = torch.tensor( + [ + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, True, True, False], + [False, False, False, True, True, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, True, False], + [False, True, True, True, True, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + ] + ) + + causal_mask = rlhf.get_causal_mask(left_padded_prompt_tokens_batched != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_right_padded_inputs(self, right_padded_prompt_tokens): + """ + Test to check if the `get_causal_mask` function produces the right output for right-padded prompts. + """ + expected_causal_mask = torch.tensor( + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ] + ).unsqueeze(0) + + causal_mask = rlhf.get_causal_mask(right_padded_prompt_tokens != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_right_padded_inputs_batched( + self, right_padded_prompt_tokens_batched + ): + """ + Test to check if the `get_causal_mask` function produces the right output for right-padded batched prompts. 
+ """ + expected_causal_mask = torch.tensor( + [ + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [True, True, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, True], + ], + ] + ) + + causal_mask = rlhf.get_causal_mask(right_padded_prompt_tokens_batched != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_mixed_padding_inputs(self, mixed_padded_prompt_tokens): + """ + Test to check if the `get_causal_mask` function produces the right output for mixed padded prompts. + """ + expected_causal_mask = torch.tensor( + [ + [True, False, False, False, False, False, False, False], + [False, True, False, False, False, False, False, False], + [False, False, True, False, False, False, False, False], + [False, False, True, True, False, False, False, False], + [False, False, True, True, True, False, False, False], + [False, False, True, True, True, True, False, False], + [False, False, False, False, False, False, True, False], + [False, False, False, False, False, False, False, True], + ] + ).unsqueeze(0) + + causal_mask = rlhf.get_causal_mask(mixed_padded_prompt_tokens != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) + + def test_get_causal_mask_for_mixed_padded_inputs_batched( + self, mixed_padded_prompt_tokens_batched + ): + """ + Test to check if the `get_causal_mask` function produces the right output for mixed-padded batched prompts. + """ + expected_causal_mask = torch.tensor( + [ + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, True, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, True, False], + [False, False, False, False, False, True], + ], + [ + [True, False, False, False, False, False], + [False, True, False, False, False, False], + [False, False, True, False, False, False], + [False, False, False, True, False, False], + [False, False, False, False, True, False], + [False, False, False, False, False, True], + ], + ] + ) + + causal_mask = rlhf.get_causal_mask(mixed_padded_prompt_tokens_batched != 0) + torch.testing.assert_close(causal_mask, expected_causal_mask, atol=0, rtol=0) diff --git a/tests/torchtune/modules/rlhf/test_rewards.py b/tests/torchtune/modules/rlhf/test_rewards.py new file mode 100644 index 0000000000..0e8ec998fa --- /dev/null +++ b/tests/torchtune/modules/rlhf/test_rewards.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtune.modules import rlhf + + +class TestGetRewards: + def test_get_rewards(self): + scores = torch.tensor([1.0, 2.0, 3.0]) + logprobs = torch.tensor( + [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.6, 0.7, 0.8], + ] + ) + ref_logprobs = torch.tensor( + [ + [0.2, 0.3, 0.4], + [0.6, 0.7, 0.8], + [0.9, 1.0, 1.1], + ] + ) + kl_controller_value = 0.5 + + # expected kl is logprobs - ref_logprobs + expected_kl = torch.tensor( + [ + [-0.1, -0.1, -0.1], + [-0.2, -0.2, -0.2], + [-0.3, -0.3, -0.3], + ] + ) + + # expected kl_rewards is -kl_controller_value * kl + expected_kl_rewards = torch.tensor( + [ + [0.05, 0.05, 0.05], + [0.1, 0.1, 0.1], + [0.15, 0.15, 0.15], + ] + ) + + # expected rewards is kl_rewards[:, -1] + scores + expected_rewards = torch.tensor( + [ + [0.05, 0.05, 1.05], + [0.1, 0.1, 2.1], + [0.15, 0.15, 3.15], + ] + ) + + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + scores, logprobs, ref_logprobs, kl_controller_value + ) + + torch.testing.assert_close(kl, expected_kl, rtol=1e-4, atol=1e-4) + torch.testing.assert_close( + kl_rewards, expected_kl_rewards, rtol=1e-4, atol=1e-4 + ) + torch.testing.assert_close(rewards, expected_rewards, rtol=1e-4, atol=1e-4) + + +class TestWhiten: + def test_whiten_with_shift_mean(self): + x = torch.normal(1, 2, size=(100, 100)) + + expected_mean, expected_var = x.mean(), x.var() # should be ~1.0, ~4.0 + expected = (x - expected_mean) / (torch.sqrt(expected_var) + 1e-8) + expected += expected_mean + output = rlhf.whiten(x, shift_mean=True) + + torch.testing.assert_close(output, expected, rtol=1e-4, atol=1e-4) + + def test_whiten_without_shift_mean(self): + x = torch.normal(1, 2, size=(100, 100)) + + expected_mean, expected_var = x.mean(), x.var() # should be ~1.0, ~4.0 + expected = (x - expected_mean) / (torch.sqrt(expected_var) + 1e-8) + output = rlhf.whiten(x, shift_mean=False) + + torch.testing.assert_close(output, expected, rtol=1e-4, atol=1e-4) + + def test_masked_whiten(self): + x_mean_1 = torch.normal(1, 2, size=(50, 100)) + x_mean_2 = torch.normal(2, 1, size=(50, 100)) + x = torch.cat([x_mean_1, x_mean_2], dim=0) + mask = torch.ones_like(x, dtype=torch.bool) + mask[:50] = False + + expected_mean, expected_var = ( + x_mean_2.mean(), + x_mean_2.var(), + ) # should be ~2.0, ~1.0 + expected = (x - expected_mean) / (torch.sqrt(expected_var) + 1e-8) + expected += expected_mean + + output = rlhf.whiten(x, mask=mask) + + torch.testing.assert_close(output, expected, rtol=1e-4, atol=1e-4) + + +class TestMaskedMean: + def test_masked_single_batch_mean(self): + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + mask = torch.tensor([True, True, True, False, False]) + + expected_mean = torch.tensor(2.0) + output = rlhf.masked_mean(x, mask) + + torch.testing.assert_close(output, expected_mean, rtol=1e-4, atol=1e-4) + + def test_masked_multi_batch_mean(self): + x = torch.tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [2.0, 3.0, 4.0, 5.0, 6.0], + ] + ) + mask = torch.tensor( + [[True, True, True, False, False], [False, False, True, True, True]] + ) + + expected_means = torch.tensor([2.0, 5.0]) + output = rlhf.masked_mean(x, mask, dim=1) + + torch.testing.assert_close(output, expected_means, rtol=1e-4, atol=1e-4) + + +class TestMaskedVar: + def test_masked_var(self): + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + mask = torch.tensor([True, True, True, False, False]) + + expected_var = torch.tensor(1.0) + output = 
rlhf.masked_var(x, mask) + + torch.testing.assert_close(output, expected_var, rtol=1e-4, atol=1e-4) + + +class TestEstimateAdvantages: + def test_estimate_returns(self): + values = torch.tensor([[0, 0, 0, 1]]) + rewards = torch.tensor([[0, 0, 0, 1]]) + gamma = 0.9 + lmbda = 0.95 + + final_reward = 1.0 + expected_returns = torch.tensor( + [ + [ + final_reward * gamma * gamma * gamma * lmbda * lmbda, + final_reward * gamma * gamma * lmbda, + final_reward * gamma, + final_reward, + ] + ] + ) + + _, returns = rlhf.estimate_advantages(values, rewards, gamma, lmbda) + torch.testing.assert_close(returns, expected_returns, rtol=1e-4, atol=1e-4) + + def test_estimate_advantages_with_whitening(self): + values = torch.tensor([[0, 0, 0, 1]]) + rewards = torch.tensor([[0, 0, 0, 1]]) + gamma = 0.9 + lmbda = 0.95 + + final_reward = 1.0 + returns = torch.tensor( + [ + [ + final_reward * gamma * gamma * gamma * lmbda * lmbda, + final_reward * gamma * gamma * lmbda, + final_reward * gamma, + final_reward, + ] + ] + ) + + # see `torchtune.modules.rlhf.estimate_advantages` + expected_advantages = returns - values + expected_whitened_advantages = rlhf.whiten(expected_advantages, shift_mean=True) + advantages, _ = rlhf.estimate_advantages(values, rewards, gamma, lmbda) + torch.testing.assert_close( + expected_whitened_advantages, advantages, rtol=1e-4, atol=1e-4 + ) + + def test_estimate_advantages_with_masks(self): + values = torch.tensor([[0, 0, 0, 1]]) + rewards = torch.tensor([[0, 0, 0, 1]]) + masks = torch.tensor([[True, True, True, False]]) + gamma = 0.9 + lmbda = 0.95 + + final_reward = 1.0 + returns = torch.tensor( + [ + [ + final_reward * gamma * gamma * gamma * lmbda * lmbda, + final_reward * gamma * gamma * lmbda, + final_reward * gamma, + final_reward, + ] + ] + ) + + # see `torchtune.modules.rlhf.estimate_advantages` + expected_advantages = returns - values + expected_advantages = rlhf.whiten(expected_advantages, mask=masks) + expected_advantages[..., -1] = 0.0 + + advantages, _ = rlhf.estimate_advantages( + values, rewards, gamma, lmbda, masks=masks + ) + torch.testing.assert_close( + advantages, expected_advantages, rtol=1e-4, atol=1e-4 + ) diff --git a/tests/torchtune/modules/rlhf/test_sequence_processing.py b/tests/torchtune/modules/rlhf/test_sequence_processing.py new file mode 100644 index 0000000000..43accdf80c --- /dev/null +++ b/tests/torchtune/modules/rlhf/test_sequence_processing.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
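As a side note on the advantage tests above: the expected returns in ``TestEstimateAdvantages`` follow the standard GAE(lambda) backward recursion. The sketch below is only an illustration of that recursion under the same inputs; it assumes no bootstrap value beyond the final step and is not taken from the torchtune implementation.

import torch

def gae_returns(values, rewards, gamma, lmbda):
    # backward recursion: A_t = delta_t + gamma * lmbda * A_{t+1},
    # with delta_t = r_t + gamma * V_{t+1} - V_t and V beyond the last step treated as 0
    advantages = torch.zeros_like(values)
    running = torch.zeros_like(values[..., 0])
    for t in reversed(range(values.shape[-1])):
        next_value = values[..., t + 1] if t < values.shape[-1] - 1 else 0.0
        delta = rewards[..., t] + gamma * next_value - values[..., t]
        running = delta + gamma * lmbda * running
        advantages[..., t] = running
    # returns are advantages plus the value baseline
    return advantages + values

values = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
print(gae_returns(values, rewards, gamma=0.9, lmbda=0.95))
# tensor([[0.6579, 0.7695, 0.9000, 1.0000]]), i.e. [g^3*l^2, g^2*l, g, 1]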
+ +import torch +from torchtune.modules import rlhf + + +class TestTruncateSequenceAtFirstStopToken: + def test_truncate_sequences(self): + stop_token_ids = torch.tensor([2, 869]) + fill_value = 0 + sequences = torch.tensor( + [ + [869, 30, 869], + [2, 30, 869], + [869, 30, 2], + [50, 30, 869], + [13, 30, 2], + [13, 30, 5], + [13, 2, 20], + [13, 2, 2], + [2, 2, 2], + ] + ) + eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token( + sequences, stop_token_ids, fill_value + ) + + expected_eos_mask = torch.tensor( + [ + [False, True, True], + [False, True, True], + [False, True, True], + [False, False, False], + [False, False, False], + [False, False, False], + [False, False, True], + [False, False, True], + [False, True, True], + ] + ) + + expected_sequences = torch.tensor( + [ + [869, fill_value, fill_value], + [2, fill_value, fill_value], + [869, fill_value, fill_value], + [50, 30, 869], + [13, 30, 2], + [13, 30, 5], + [13, 2, fill_value], + [13, 2, fill_value], + [2, fill_value, fill_value], + ] + ) + + assert expected_eos_mask.eq(eos_mask).all() + assert expected_sequences.eq(truncated_sequences).all() diff --git a/tests/torchtune/utils/test_pooling.py b/tests/torchtune/utils/test_pooling.py index 13c870954c..223bcca33f 100644 --- a/tests/torchtune/utils/test_pooling.py +++ b/tests/torchtune/utils/test_pooling.py @@ -4,52 +4,53 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch -from torchtune.utils.pooling import pool_sequence_logits +from torchtune.utils.pooling import get_unmasked_sequence_lengths -class TestPooling: - def test_pool_sequence_logits_multi_batch(self): +class TestGetLastUnmaskedTokenIdx: + def test_get_last_unmasked_token_idx_multi_batch(self): """ - Tests that the last non-padding token logits are pooled correctly for a multi-batch input. + Tests that the last non-padding tokens are correctly selected for a multi-batch input. """ padding_token_idx = 0 tokens = torch.tensor([[1, 3, 4, 9], [4, 5, 6, 0], [1, 0, 0, 0], [0, 0, 0, 0]]) - logits = torch.tensor( - [ - [[0.1, 1.3, 1.4], [0.5, 0.6, 0.7], [0.9, 1.1, 1.2], [1.3, 0.5, 1.6]], - [[0.2, 1.4, 1.5], [0.6, 0.7, 0.8], [1.0, 1.2, 1.3], [1.4, 1.6, 0.7]], - [[0.3, 1.5, 1.6], [0.1, 1.8, 0.2], [1.1, 1.3, 1.4], [0.5, 1.7, 0.1]], - [[0.4, 1.6, 1.7], [0.8, 0.9, 1.0], [1.2, 1.4, 1.5], [0.6, 1.8, 0.2]], - ] - ) - expected_output = torch.tensor( - [ - [1.3, 0.5, 1.6], - [1.0, 1.2, 1.3], - [0.3, 1.5, 1.6], - [0.4, 1.6, 1.7], - ] - ) - output = pool_sequence_logits(tokens, logits, padding_token_idx) - torch.testing.assert_close(output, expected_output) + expected_output = torch.tensor([3, 2, 0, 0]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + torch.testing.assert_close(idxs, expected_output) - def test_pool_sequence_logits_single_batch(self): + def test_get_last_unmasked_token_idx_single_batch(self): """ - Tests that the last non-padding token logits are pooled correctly for a single-batch input. + Tests that the last non-padding tokens are correctly selected for a single-batch input. 
""" padding_token_idx = 0 - tokens = torch.tensor([[1, 3, 4, 9]]) - logits = torch.tensor( - [ - [[0.1, 1.3, 1.4], [0.5, 0.6, 0.7], [0.9, 1.1, 1.2], [1.3, 0.5, 1.6]], - ] - ) - expected_output = torch.tensor( - [ - [1.3, 0.5, 1.6], - ] - ) - output = pool_sequence_logits( - tokens, logits, padding_token_idx=padding_token_idx + tokens = torch.tensor([[1, 3, 4, 9, 0]]) + expected_output = torch.tensor([3]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + + torch.testing.assert_close(idxs, expected_output) + + def test_get_last_unmasked_token_idx_multi_batch_all_full(self): + """ + Tests that the last non-padding tokens are correctly selected for multi-batch input, + where none of the sequences have padding tokens. + """ + padding_token_idx = 0 + tokens = torch.tensor( + [[1, 3, 4, 9], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ) - torch.testing.assert_close(output, expected_output) + expected_output = torch.tensor([3, 3, 3, 3]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + + torch.testing.assert_close(idxs, expected_output) + + def test_get_last_unmasked_token_idx_multi_batch_all_empty(self): + """ + Tests that the last non-padding tokens are correctly selected for multi-batch input, + where none of the sequences have any non-padding tokens. + """ + padding_token_idx = 0 + tokens = torch.zeros((4, 4), dtype=torch.long) + expected_output = torch.tensor([0, 0, 0, 0]) + idxs = get_unmasked_sequence_lengths(tokens == padding_token_idx) + + torch.testing.assert_close(idxs, expected_output) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 9f40bf6802..cf0cad5e20 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -177,6 +177,21 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="ppo_full_finetune_single_device", + file_path="ppo_full_finetune_single_device.py", + configs=[ + Config( + name="llama2/1B_full_ppo", + file_path="llama2/1B_full_ppo.yaml", + ), + Config( + name="mistral/7B_full_ppo_low_memory", + file_path="mistral/7B_full_ppo_low_memory.yaml", + ), + ], + supports_distributed=False, + ), Recipe( name="lora_finetune_distributed", file_path="lora_finetune_distributed.py", diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 432a128986..8682493c8d 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -41,6 +41,8 @@ "llama2_70b", "llama2_7b", "llama2_tokenizer", + "lora_llama2", + "llama2_classifier", "lora_llama2_13b", "lora_llama2_70b", "lora_llama2_7b", diff --git a/torchtune/modules/loss/__init__.py b/torchtune/modules/loss/__init__.py index 522bd868a7..5c02d17e0b 100644 --- a/torchtune/modules/loss/__init__.py +++ b/torchtune/modules/loss/__init__.py @@ -5,5 +5,6 @@ # LICENSE file in the root directory of this source tree. from .dpo import DPOLoss, IPOLoss, RSOLoss +from .ppo import PPOLoss -__all__ = ["DPOLoss", "RSOLoss", "IPOLoss"] +__all__ = ["DPOLoss", "RSOLoss", "IPOLoss", "PPOLoss"] diff --git a/torchtune/modules/loss/ppo.py b/torchtune/modules/loss/ppo.py new file mode 100644 index 0000000000..0cef4a5301 --- /dev/null +++ b/torchtune/modules/loss/ppo.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torchtune.modules import rlhf
+
+
+class PPOLoss(nn.Module):
+    """
+    Proximal Policy Optimization (PPO) Loss module.
+    This implementation uses the following references:
+
+    https://arxiv.org/abs/1707.06347 eqn. 7
+
+    https://github.com/vwxyzjn/lm-human-preference-details/blob/ccc19538e817e98a60d3253242ac15e2a562cb49/lm_human_preference_details/train_policy_accelerate.py#L719
+
+    https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75
+
+
+    Args:
+        epsilon (float): clipping range for PPO update.
+        value_clip_range (float): clipping range for value function update.
+        value_coeff (float): coefficient for the value function loss contribution.
+    """
+
+    def __init__(
+        self,
+        epsilon: float = 0.1,
+        value_clip_range: float = 0.2,
+        value_coeff: float = 0.1,
+    ):
+        super().__init__()
+        self.epsilon = epsilon
+        self.value_clip_range = value_clip_range
+        self.value_coeff = value_coeff
+
+    def forward(
+        self,
+        pi_old_logprobs: torch.Tensor,
+        pi_logprobs: torch.Tensor,
+        advantages: torch.Tensor,
+        phi_old_values: torch.Tensor,
+        phi_values: torch.Tensor,
+        returns: torch.Tensor,
+        padding_masks: Optional[torch.Tensor] = None,
+        value_padding_masks: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Forward pass of the PPO loss module.
+
+        Args:
+            pi_old_logprobs (torch.Tensor): Log probabilities of the old policy.
+            pi_logprobs (torch.Tensor): Log probabilities of the current policy.
+            advantages (torch.Tensor): Advantage values.
+            phi_old_values (torch.Tensor): Value predictions of the old value function.
+            phi_values (torch.Tensor): Value predictions of the current value function.
+            returns (torch.Tensor): Return values.
+            padding_masks (Optional[torch.Tensor]): Padding token masks of the same shape as ``pi_logprobs``,
+                where True indicates the corresponding loss values should participate in the policy loss calculation.
+            value_padding_masks (Optional[torch.Tensor]): Padding token masks of the same shape as ``pi_logprobs``,
+                where True indicates the corresponding loss values should participate in the value loss calculation.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of five tensors:
+                - loss: The total PPO loss.
+                - policy_loss: The policy function loss.
+                - value_loss: The value function loss.
+                - ratios: The ratio between the current and old policy probabilities.
+                - clipfrac: The fraction of ratios that were clipped.
+ + """ + ratios = torch.exp(pi_logprobs - pi_old_logprobs) + clipped_ratios = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon) + + policy_losses_clipped = -advantages * clipped_ratios + policy_losses_unclipped = -advantages * ratios + + clipfrac = (policy_losses_clipped > policy_losses_unclipped).float() + clipfrac = ( + clipfrac.mean() + if padding_masks is None + else rlhf.masked_mean(clipfrac, padding_masks) + ) + + policy_loss = torch.maximum(policy_losses_clipped, policy_losses_unclipped) + policy_loss = ( + policy_loss.mean() + if padding_masks is None + else rlhf.masked_mean(policy_loss, padding_masks) + ) + + values_clipped = torch.clamp( + phi_values, + phi_old_values - self.value_clip_range, + phi_old_values + self.value_clip_range, + ) + value_loss = torch.maximum( + (phi_values - returns) ** 2, (values_clipped - returns) ** 2 + ) + value_loss = ( + 0.5 * value_loss.mean() + if value_padding_masks is None + else 0.5 * rlhf.masked_mean(value_loss, value_padding_masks) + ) + + loss = policy_loss + (value_loss * self.value_coeff) + return ( + loss, + policy_loss.detach(), + value_loss.detach(), + ratios.mean().detach(), + clipfrac.detach(), + ) diff --git a/torchtune/modules/peft/peft_utils.py b/torchtune/modules/peft/peft_utils.py index b3c55ae389..eb86c6afd2 100644 --- a/torchtune/modules/peft/peft_utils.py +++ b/torchtune/modules/peft/peft_utils.py @@ -320,7 +320,7 @@ def validate_missing_and_unexpected_for_lora( Raises: AssertionError: if base_missing contains any base model keys. - AssertionError: if base_unexpect is nonempty. + AssertionError: if base_unexpected is nonempty. AssertionError: if lora_missing contains any LoRA keys. AssertionError: if lora_unexpected is nonempty. """ diff --git a/torchtune/modules/rlhf/__init__.py b/torchtune/modules/rlhf/__init__.py index 2e41cd717f..307f24f801 100644 --- a/torchtune/modules/rlhf/__init__.py +++ b/torchtune/modules/rlhf/__init__.py @@ -3,3 +3,44 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from ._generation import ( + generate_next_token_with_logits, + generate_with_logits, + get_causal_mask, +) + +from ._types import PPOStats, Trajectory +from .collate import left_padded_collate, padded_collate_dpo +from .rewards import ( + estimate_advantages, + get_reward_penalty_mask, + get_rewards_ppo, + masked_mean, + masked_var, + whiten, +) +from .sequence_processing import ( + logits_to_logprobs, + truncate_sequence_at_first_stop_token, + truncate_sequence_for_logprobs, +) + +__all__ = [ + "generate_with_logits", + "generate_next_token_with_logits", + "truncate_sequence_at_first_stop_token", + "get_causal_mask", + "logits_to_logprobs", + "truncate_sequence_for_logprobs", + "get_reward_penalty_mask", + "left_padded_collate", + "padded_collate_dpo", + "estimate_advantages", + "get_rewards_ppo", + "whiten", + "masked_mean", + "masked_var", + "PPOStats", + "Trajectory", +] diff --git a/torchtune/modules/rlhf/_generation.py b/torchtune/modules/rlhf/_generation.py new file mode 100644 index 0000000000..a7f54b6fe5 --- /dev/null +++ b/torchtune/modules/rlhf/_generation.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional, Tuple + +import torch +from torchtune.modules.transformer import TransformerDecoder + + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + """Samples from a multinomial distribution.""" + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: int = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + """Generic sample from a probability distribution.""" + # scale the logits based on temperature + logits = logits / max(temperature, 1e-5) + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # select the very last value from the top_k above as the pivot + pivot = v.select(-1, -1).unsqueeze(-1) + # set everything smaller than pivot value to inf since these + # should be pruned + logits = torch.where(logits < pivot, -float("Inf"), logits) + # change logits into probabilities + probs = torch.nn.functional.softmax(logits, dim=-1) + return multinomial_sample_one(probs, rng) + + +def generate_next_token_with_logits( + model: TransformerDecoder, + input_pos: torch.Tensor, + x: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generates the next tokens given a prompt, and also returns the corresponding logits. + + Args: + model (TransformerDecoder): model used for generation + input_pos (torch.Tensor): tensor with the positional encodings associated with the given prompt, + with shape [bsz x seq_length]. + x (torch.Tensor): tensor with the token IDs associated with the given prompt, + with shape [bsz x seq_length]. + mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length], + default None. + temperature (float): value to scale the predicted logits by, default 1.0. + top_k (Optional[int]): Top-k value to use for sampling, default None. + rng (Optional[torch.Generator]): random number generator, default None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: + - logits (torch.Tensor): tensor with the logits associated with the generated tokens, + with shape [bsz x seq_length x vocab_size]. + - tokens (torch.Tensor): tensor with the generated tokens, + with shape [bsz x 1]. + + """ + # model produces logits in [bsz, seq_length, vocab_size] + # we want to take the last token's logits as the input to the next model call + logits = model(x, input_pos=input_pos, mask=mask) + return logits, sample(logits[:, -1].clone(), temperature, top_k, rng) + + +def get_causal_mask( + padding_mask: torch.Tensor, +) -> torch.Tensor: + """ + Converts an attention mask of shape ``[bsz, seq_len]`` to a causal attention mask suitable for + consumption by :func:`~torch.nn.functional.scaled_dot_product_attention~`. 
+ + HF uses a similar implementation internally, see + https://github.com/huggingface/transformers/blob/a564d10afe1a78c31934f0492422700f61a0ffc0/src/transformers/models/mistral/modeling_mistral.py#L1096 + + Args: + padding_mask (torch.Tensor): Boolean tensor where True indicates participation in attention + with shape [bsz x seq_length] + Returns: + torch.Tensor: Boolean causal mask with shape [bsz x seq_length x seq_length] + """ + _, seq_len = padding_mask.shape + mask = torch.tril( + torch.ones(seq_len, seq_len, device=padding_mask.device, dtype=bool), diagonal=0 + ) + mask = mask & (padding_mask[:, None, :] & padding_mask[:, :, None]) + mask.diagonal(dim1=1, dim2=2)[:] = True + return mask + + +@torch.inference_mode() +def generate_with_logits( + model: TransformerDecoder, + prompt: torch.Tensor, + *, + max_generated_tokens: int, + pad_id: int = 0, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +): + """ + Generates tokens from a model conditioned on a prompt, and also returns logits for the generations. + + Args: + model (TransformerDecoder): model used for generation + prompt (torch.Tensor): tensor with the token IDs associated with the given prompt, + with shape either [seq_length] or [bsz x seq_length]. + max_generated_tokens (int): number of tokens to be generated + pad_id (int): token ID to use for padding, default 0. + temperature (float): value to scale the predicted logits by, default 1.0. + top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities, + default None. + rng (Optional[torch.Generator]): random number generator, default None. + + Examples: + >>> model = torchtune.models.llama3.llama3_8b() + >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() + >>> prompt = [0, 0, 0] + tokenizer("Hi my name is") # substitute 0 with pad_id + >>> rng = torch.Generator() # optionally place on device + >>> rng.manual_seed(42) + >>> output = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0, rng=rng) + >>> print(tokenizer.decode(output[0])) + ?? ?? ?? Hi my name is Jeremy and I'm a friendly language model assistant! + + Returns: + torch.Tensor: Generated tokens. + """ + prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt + + _, prompt_length = prompt.size() + generated_tokens = prompt.clone() + + for i in range(max_generated_tokens): + padding_masks = generated_tokens == pad_id + if padding_masks.any(): + mask = get_causal_mask(~padding_masks) + input_pos = (~padding_masks).cumsum(-1) - (~padding_masks).long() + input_pos = input_pos.to(torch.int) + else: + mask = None + input_pos = torch.arange( + 0, prompt_length + i, device=generated_tokens.device + ) + + logits, tokens = generate_next_token_with_logits( + model, + input_pos=input_pos, + x=generated_tokens, + mask=mask, + temperature=temperature, + top_k=top_k, + rng=rng, + ) + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + + return generated_tokens, logits diff --git a/torchtune/modules/rlhf/_types.py b/torchtune/modules/rlhf/_types.py new file mode 100644 index 0000000000..729a4035fc --- /dev/null +++ b/torchtune/modules/rlhf/_types.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
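As an illustration of how the generation helpers above handle left padding, the sketch below builds the attention mask and position ids for a single left-padded prompt the same way ``generate_with_logits`` does internally (toy token ids; the padding id of 0 matches the function's default).

import torch
from torchtune.modules import rlhf

tokens = torch.tensor([[0, 0, 5, 6, 7]])  # two pad tokens followed by the prompt
padding_masks = tokens == 0

# boolean [bsz, seq_len, seq_len] mask; padded positions attend only to themselves
mask = rlhf.get_causal_mask(~padding_masks)

# position ids restart from 0 at the first non-padded token
input_pos = ((~padding_masks).cumsum(-1) - (~padding_masks).long()).to(torch.int)
print(input_pos)  # tensor([[0, 0, 0, 1, 2]], dtype=torch.int32)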
+ +from typing import NamedTuple + +import torch + + +class Trajectory(NamedTuple): + """ + Contains a collection of tensors describing a generated trajectory during RLHF + + Attributes: + query_responses (torch.Tensor): (query, response) pairs + shape [b, context_length + max_generated_tokens] + logprobs (torch.Tensor): log probabilities of the generated responses with shape [b, max_generated_tokens] + ref_logprobs (torch.Tensor): log probabilities of the generated responses using the reference policy + shape [b, max_generated_tokens] + values (torch.Tensor): value estimates of the generated responses with shape [b, max_generated_tokens] + masks (torch.Tensor): attention masks for input ids-generated responses pairs + shape [b, context_length + max_generated_tokens, context_length + max_generated_tokens] + position_ids (torch.Tensor): position IDs for input ids-generated responses pairs + shape [b, context_length + max_generated_tokens] + response_padding_masks (torch.Tensor): padding masks for the truncated and padded generated responses + shape [b, max_generated_tokens] + value_padding_masks (torch.Tensor): padding masks for the values with + shape [b, max_generated_tokens] + value_seq_idxs (torch.Tensor): indexes of the token + after the last valid (non-padding) token in the responses with shape [b] + scores (torch.Tensor): scores from the reward model with shape [b] + seq_lens (torch.Tensor): sequence lengths of truncated generated responses with shape [b] + """ + + query_responses: torch.Tensor + logprobs: torch.Tensor + ref_logprobs: torch.Tensor + values: torch.Tensor + masks: torch.Tensor + position_ids: torch.Tensor + response_padding_masks: torch.Tensor + value_padding_masks: torch.Tensor + value_seq_idxs: torch.Tensor + scores: torch.Tensor + seq_lens: torch.Tensor + + +class PPOStats(NamedTuple): + """ + Contains PPO loss statistics (metrics) + + Attributes: + loss (torch.Tensor): The total PPO loss. + policy_loss (torch.Tensor): The policy function loss. + value_loss (torch.Tensor): The value function loss. + ratios (torch.Tensor): The ratio between the current and old policy probabilities. + clipfrac (torch.Tensor): The fraction of ratios that were clipped. + approx_policy_kls (torch.Tensor): Average estimated KL divergence between the policy before and after the optimisation step. + + """ + + loss: torch.Tensor + policy_loss: torch.Tensor + value_loss: torch.Tensor + ratios: torch.Tensor + clipfrac: torch.Tensor + approx_policy_kls: torch.Tensor diff --git a/torchtune/modules/rlhf/collate.py b/torchtune/modules/rlhf/collate.py new file mode 100644 index 0000000000..fa20beb2b0 --- /dev/null +++ b/torchtune/modules/rlhf/collate.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Tuple + +import torch +from torch.nn.utils.rnn import pad_sequence + +from torchtune.data import CROSS_ENTROPY_IGNORE_IDX + + +def left_padded_collate( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, +) -> torch.Tensor: + """ + Pads a batch of sequences with left padding to the maximum sequence length in the batch. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries containing inputs. + padding_idx (int): The padding index. Defaults to 0. + + Returns: + torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len]. 
+ + Example: + >>> padding_idx = -8 + >>> batch = [ + >>> {"tokens": [1, 2] }, + >>> {"tokens": [3] }, + >>> {"tokens": [4, 5, 6, 7]}, + >>> ] + >>> left_padded_collate(batch, padding_idx) + >>> tensor([[-8, -8, 1, 2], + >>> [-8, -8, -8, 3], + >>> [ 4, 5, 6, 7]]) + + """ + pad_toks = pad_sequence( + [torch.tensor(x["tokens"][::-1]) for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + seq_idxs_rev = torch.arange(pad_toks.shape[-1] - 1, -1, -1) + return torch.stack([tok[seq_idxs_rev] for tok in pad_toks]) + + +def padded_collate_dpo( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, + ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pad a batch of sequences for Direct Preference Optimization (DPO). + + This function takes a batch of sequences, where each sequence is represented + as a dictionary with multiple key-value pairs. Each key corresponds to a different + sequence component, such as input_ids or labels. + + This will raise: + AssertionError: if the length of chosen_input_ids and rejected_input_ids differ. + AssertionError: if the length of chosen_labels and rejected_labels differ. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries, where each dictionary + represents a sequence with multiple components, 'chosen_input_ids', + 'chosen_labels', 'rejected_input_ids', and 'rejected_labels' are required. + padding_idx (int): Padding index for input ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing concatenated and padded + input ids and labels. + + + Example: + >>> batch = [ + >>> {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5], + >>> 'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]}, + >>> {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15], + >>> 'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]}, + >>> ] + >>> padded_collate_dpo(batch) + >>> (tensor([[ 1, 2, 3], + >>> [11, 12, 0], + >>> [ 4, 5, 0], + >>> [13, 14, 15]]), + >>> tensor([[ 6, 7, 8], + >>> [16, 17, -100], + >>> [ 9, 10, -100], + >>> [18, 19, 20]])) + """ + chosen_input_ids = [torch.tensor(ex["chosen_input_ids"]) for ex in batch] + rejected_input_ids = [torch.tensor(ex["rejected_input_ids"]) for ex in batch] + chosen_labels = [torch.tensor(ex["chosen_labels"]) for ex in batch] + rejected_labels = [torch.tensor(ex["rejected_labels"]) for ex in batch] + + assert len(chosen_input_ids) == len(rejected_input_ids) + assert len(chosen_labels) == len(rejected_labels) + + to_pad_input_ids = chosen_input_ids + rejected_input_ids + to_pad_labels = chosen_labels + rejected_labels + + concatenated_input_ids = pad_sequence( + to_pad_input_ids, batch_first=True, padding_value=padding_idx + ) + concatenated_labels = pad_sequence( + to_pad_labels, batch_first=True, padding_value=ignore_idx + ) + + return concatenated_input_ids, concatenated_labels diff --git a/torchtune/modules/rlhf/rewards.py b/torchtune/modules/rlhf/rewards.py new file mode 100644 index 0000000000..0e5994ff1d --- /dev/null +++ b/torchtune/modules/rlhf/rewards.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
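The left padding in ``left_padded_collate`` above works by reversing each sequence, right-padding with ``pad_sequence``, and then restoring the original order. A self-contained sketch of the same trick, independent of torchtune and using ``flip`` instead of the index-select in the function above:

import torch
from torch.nn.utils.rnn import pad_sequence

def left_pad(seqs, padding_value=0):
    # reverse each sequence, right-pad, then flip the padded batch back
    reversed_padded = pad_sequence(
        [torch.tensor(s[::-1]) for s in seqs],
        batch_first=True,
        padding_value=padding_value,
    )
    return reversed_padded.flip(dims=[-1])

print(left_pad([[1, 2], [3], [4, 5, 6, 7]], padding_value=-8))
# tensor([[-8, -8,  1,  2],
#         [-8, -8, -8,  3],
#         [ 4,  5,  6,  7]])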
+
+from typing import Optional, Tuple
+
+import torch
+
+
+def get_reward_penalty_mask(
+    padding_masks: torch.Tensor,
+    seq_lens: torch.Tensor,
+    penalise_no_eos: bool = True,
+    min_response_length: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    Calculates a mask to penalise scores corresponding to sequences generated during PPO, where True indicates the score
+    at the corresponding position should be penalised.
+    This function assumes sequences have already been truncated at an EOS, if present, and padded to length,
+    e.g. by :func:`torchtune.modules.rlhf.sequence_processing.truncate_sequence_at_first_stop_token`.
+
+    Scores are penalised such that:
+    - If ``min_response_length`` is set, scores for sequences with ``length < min_response_length`` are penalised.
+    - If ``penalise_no_eos`` is True, scores for sequences with no EOS token are penalised.
+
+    Args:
+        padding_masks (torch.Tensor): Tensor where True indicates a padding token in the generated
+            sequence, and False otherwise. Shape: (b, response_len)
+        seq_lens (torch.Tensor): The length of each generated sequence. Shape: (b,)
+        penalise_no_eos (bool, optional): Whether to penalise sequences with no EOS token. Defaults to True.
+        min_response_length (Optional[int]): The minimum length of the response. If set, any response shorter
+            than this length is penalised. Defaults to None.
+
+    Returns:
+        torch.Tensor: A mask tensor with shape (b,) where True indicates the corresponding score should be penalised.
+    """
+    reward_penalty_mask = torch.zeros_like(seq_lens).to(bool)
+
+    # since sequences will have been truncated at EOS, we can mask based on the presence of any padding tokens
+    if penalise_no_eos:
+        reward_penalty_mask = ~padding_masks.any(-1)
+
+    if min_response_length is not None:
+        reward_penalty_mask |= ~(seq_lens >= min_response_length)
+    return reward_penalty_mask
+
+
+def get_rewards_ppo(
+    scores: torch.Tensor,
+    logprobs: torch.Tensor,
+    ref_logprobs: torch.Tensor,
+    kl_coeff: float,
+    valid_score_idxs: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Calculates PPO rewards for the given scores, logprobs, and reference logprobs.
+
+    Args:
+        scores (torch.Tensor): Reward model scores, shape (b,).
+        logprobs (torch.Tensor): Policy logprobs, shape (b, response_len).
+        ref_logprobs (torch.Tensor): Reference policy logprobs, shape (b, response_len).
+        kl_coeff (float): KL reward contribution coefficient.
+        valid_score_idxs (Optional[torch.Tensor]): A tensor of indexes for valid (non-padded) token predictions.
+            This is useful when calculating rewards for padded sequences, as scores and value estimates are defined
+            for the last valid predicted token. Shape: (b,). Default None.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of tensors with shape [b, response_len] each:
+            - total_reward: total reward combining per-token kl rewards and reward model score.
+            - kl: kl divergence between policy and reference policy logprobs.
+            - kl_reward: kl divergence scaled by ``kl_coeff``.
+
+    Notation used for tensor shapes:
+        - b: batch size
+        - response_len: model response length
+    """
+
+    # 1. calculate kl between logprobs and ref_logprobs
+    # 2. calculate kl reward using adaptive scaling value
+    # 3. calculate total reward by summing the above
+    # return all
+    kl = logprobs - ref_logprobs
+    kl_reward = -kl_coeff * kl
+
+    total_reward = kl_reward.clone()
+
+    # adding reward to kl at final valid position
+    # https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L153
+
+    if valid_score_idxs is not None:
+        total_reward[
+            torch.arange(scores.shape[0], device=scores.device), valid_score_idxs
+        ] += scores
+    else:
+        total_reward[:, -1] += scores
+
+    return total_reward, kl, kl_reward
+
+
+def masked_mean(
+    x: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None
+) -> torch.Tensor:
+    """
+    Compute mean of tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py
+
+    Args:
+        x (torch.Tensor): The input tensor.
+        mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x``
+            should participate in the mean calculation.
+        dim (Optional[int]): The axis to calculate the mean over. Default None.
+
+    Returns:
+        torch.Tensor: The mean tensor.
+    """
+    return (x * mask).sum(dim=dim) / mask.sum(dim=dim)
+
+
+def masked_var(
+    x: torch.Tensor, mask: torch.Tensor, unbiased: bool = True
+) -> torch.Tensor:
+    """
+    Compute variance of tensor with masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py
+
+    Args:
+        x (torch.Tensor): The input tensor.
+        mask (torch.Tensor): The bool mask tensor, where True indicates the corresponding value in ``x``
+            should participate in the variance calculation.
+        unbiased (bool): Whether to use the unbiased (Bessel-corrected) variance.
+
+    Returns:
+        torch.Tensor: The variance tensor.
+
+    Raises:
+        ValueError: If the sum of the mask is zero.
+    """
+    mean = masked_mean(x, mask)
+    centered_values = x - mean
+    var = masked_mean(centered_values.pow(2), mask)
+    if unbiased:
+        mask_sum = mask.sum()
+        if mask_sum == 0:
+            raise ValueError(
+                "The sum of the mask is zero, which can happen when ``ppo_batch_size=1``; "
+                "try increasing ``ppo_batch_size`` or ``gradient_accumulation_steps``"
+            )
+        # note that if mask_sum == 1, there is a division-by-zero issue below;
+        # to avoid it, use a larger minibatch size
+        bessel_correction = mask_sum / (mask_sum - 1)
+        var = var * bessel_correction
+    return var
+
+
+def whiten(
+    x: torch.Tensor, mask: Optional[torch.Tensor] = None, shift_mean: bool = True
+) -> torch.Tensor:
+    """
+    Whitens (normalises) values, optionally using masked values. Taken from https://github.com/huggingface/trl/blob/main/trl/core.py
+
+    Args:
+        x (torch.Tensor): The input tensor.
+        mask (Optional[torch.Tensor]): The bool mask tensor, where True indicates the corresponding value in ``x``
+            should participate in the mean calculation. Default None.
+        shift_mean (bool): Whether to shift normalised values by the mean.
+
+    Returns:
+        torch.Tensor: The whitened tensor.
+ """ + if mask is not None: + mean = masked_mean(x, mask) + var = masked_var(x, mask) if mask.any() else x.var() + else: + mean, var = x.mean(), x.var() + whitened = (x - mean) * torch.rsqrt(var + 1e-8) + if shift_mean: + whitened += mean + return whitened + + +def estimate_advantages( + values: torch.Tensor, + rewards: torch.Tensor, + gamma: float, + lmbda: float, + masks: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Estimates the advantages and returns for the PPO algorithm using Generalized Advantage Estimation + https://arxiv.org/pdf/1506.02438.pdf + + Args: + values (torch.Tensor): The predicted values for each state. Shape: (b, reponse_len) + rewards (torch.Tensor): The rewards received at each time step. Shape: (b, reponse_len) + gamma (float): The discount factor. + lmbda (float): The GAE-Lambda parameter. + masks (Optional[torch.Tensor]): A bool mask tensor, where True indicates the corresponding value in ``values`` + should participate in the mean calculation. Default None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the estimated advantages and returns. + - advantages (torch.Tensor): The estimated advantages. Shape: (b, reponse_len) + - returns (torch.Tensor): The estimated returns. Shape: (b, reponse_len) + Notation: + - b: batch size + - reponse_len: model response length + """ + + last_gae_lam = 0 + advantages_reversed = [] + + response_length = values.shape[-1] + + # estimate advantage for every predicted token position + for t in reversed(range(response_length)): + # value of the next state + next_values = values[:, t + 1] if t < response_length - 1 else 0.0 + # exponentially discounted temporal difference error: + # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) + delta = rewards[:, t] + gamma * next_values - values[:, t] + # GAE-Lambda advantage discounting saved for the next iteration + # as A_t = delta_t + gamma * lambda * A_{t+1} + ... + last_gae_lam = delta + gamma * lmbda * last_gae_lam + advantages_reversed.append(last_gae_lam) + + advantages = torch.stack(advantages_reversed[::-1], axis=1) + + # returns are the expected value of taking action a_t at each timepoint over + # a trajectory. the value estimates v_t are the expected value over all actions + # over a trajectory - the advantage is the difference between the two + returns = advantages + values + + # normalize advantages across the batch of trajectories to reduce variance + if masks is not None: + advantages = whiten(advantages, mask=masks) + advantages[~masks] = 0.0 + else: + advantages = whiten(advantages) + + return advantages, returns diff --git a/torchtune/modules/rlhf/sequence_processing.py b/torchtune/modules/rlhf/sequence_processing.py new file mode 100644 index 0000000000..58f6bf3149 --- /dev/null +++ b/torchtune/modules/rlhf/sequence_processing.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn.functional as F + + +def truncate_sequence_at_first_stop_token( + sequences: torch.Tensor, stop_tokens: torch.Tensor, fill_value: int = 0 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Truncates sequence(s) after the first stop token and pads with ``fill_value``. + + Args: + sequences (torch.Tensor): tensor of shape [batch_size, sequence_length] or [sequence_length]. 
+ stop_tokens (torch.Tensor): tensor containing stop tokens. + fill_value (int): value to pad the sequence with after the first stop token, usually ``pad_id``. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors with the same shape as ``sequences``: + - padding_mask (torch.Tensor): a bool tensor where True indicates the token has been truncated. + - sequences (torch.Tensor) a tensor of truncated and padded sequences. + + Example: + >>> stop_token_ids = torch.tensor([2, 869]) + >>> fill_value = 0 + >>> sequences = torch.tensor( + >>> [ + >>> [869, 30, 869], + >>> [2, 30, 869], + >>> [869, 30, 2], + >>> [50, 30, 869], + >>> [13, 30, 2], + >>> [13, 30, 5], + >>> [13, 2, 20], + >>> [13, 2, 2], + >>> [2, 2, 2], + >>> ] + >>> ) + >>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token( + >>> sequences, stop_token_ids, fill_value + >>> ) + >>> eos_mask + >>> torch.tensor([ + >>> [False, True, True], + >>> [False, True, True], + >>> [False, True, True], + >>> [False, False, False], + >>> [False, False, False], + >>> [False, False, False], + >>> [False, False, True], + >>> [False, False, True], + >>> [False, True, True], + >>> ] + >>> ) + >>> truncated_sequences + >>> torch.tensor([ + >>> [869, 0, 0], + >>> [2, 0, 0], + >>> [869, 0, 0], + >>> [50, 30, 869], + >>> [13, 30, 2], + >>> [13, 30, 5], + >>> [13, 2, 0], + >>> [13, 2, 0], + >>> [2, 0, 0], + >>> ] + >>> ) + """ + eos_mask = torch.isin(sequences, stop_tokens) + seq_lens = torch.cumsum(eos_mask, dim=1) + padding_mask = (seq_lens > 1) | ((seq_lens == 1) & ~eos_mask) + sequences[padding_mask] = fill_value + return padding_mask, sequences + + +def logits_to_logprobs( + logits: torch.Tensor, sequences: torch.Tensor, temperature: float = 1.0 +) -> torch.Tensor: + """ + Converts logits corresponding to a generated sequence to logprobs over the generated tokens. + + Args: + logits (torch.Tensor): The logits tensor of shape [b, response_length, vocab_size]. + sequences (torch.Tensor): The corresponding tokens of shape [b, response_length]. + temperature (float): The temperature to scale the logits. Default 1.0 + Returns: + torch.Tensor: The log probabilities corresponding to each token in ``sequences``. Shape [b, response_length]. + """ + return torch.gather( + F.log_softmax(logits / temperature, dim=-1), + 2, + sequences.unsqueeze(-1), + ).squeeze(-1) + + +def truncate_sequence_for_logprobs( + query_response_logits: torch.Tensor, context_length: int +) -> torch.Tensor: + """ + Truncates logits generated over a sequence for estimating logprobs over the tokens in the sequence. + This assumes the sequence is of the (query, response) format with length (context_length + response_length) + Args: + query_response_logits (torch.Tensor): The logits tensor of shape [b, context_length + response_length, vocab_size]. + context_length (int): The length of the context. + + Returns: + torch.Tensor: The truncated logits for the response with shape [b, response_length, vocab_size].""" + return query_response_logits[:, context_length - 1 : -1] diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index 300ab57ed4..c03c444106 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from ._checkpointing import ( # noqa + Checkpointer, FullModelHFCheckpointer, FullModelMetaCheckpointer, FullModelTorchTuneCheckpointer, @@ -40,7 +41,7 @@ ) from ._version import torch_version_ge from .argparse import TuneRecipeArgumentParser -from .collate import padded_collate, padded_collate_dpo +from .collate import padded_collate from .constants import ( # noqa ADAPTER_CONFIG, ADAPTER_KEY, @@ -48,7 +49,9 @@ MAX_STEPS_KEY, MODEL_KEY, OPT_KEY, + RNG_KEY, SEED_KEY, + STEPS_KEY, TOTAL_EPOCHS_KEY, ) from .logging import get_logger @@ -61,6 +64,7 @@ register_optim_in_bwd_hooks, set_activation_checkpointing, ) +from .pooling import get_unmasked_sequence_lengths from .precision import get_dtype, set_default_dtype, validate_expected_param_dtype from .quantization import get_quantizer_mode @@ -79,7 +83,7 @@ "lora_fsdp_wrap_policy", "get_full_finetune_fsdp_wrap_policy", "padded_collate", - "padded_collate_dpo", + "get_unmasked_sequence_lengths", "set_activation_checkpointing", "set_default_dtype", "set_seed", diff --git a/torchtune/utils/_checkpointing/__init__.py b/torchtune/utils/_checkpointing/__init__.py index 30c7e5e1fb..2c9a83da90 100644 --- a/torchtune/utils/_checkpointing/__init__.py +++ b/torchtune/utils/_checkpointing/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Union from ._checkpointer import ( # noqa FullModelHFCheckpointer, @@ -11,9 +12,17 @@ ) from ._checkpointer_utils import ModelType # noqa + +Checkpointer = Union[ + FullModelHFCheckpointer, + FullModelMetaCheckpointer, + FullModelTorchTuneCheckpointer, +] + __all__ = [ "FullModelHFCheckpointer", "FullModelMetaCheckpointer", "FullModelTorchTuneCheckpointer", "ModelType", + "Checkpointer", ] diff --git a/torchtune/utils/collate.py b/torchtune/utils/collate.py index 1cdc5362ee..e0b779ffbd 100644 --- a/torchtune/utils/collate.py +++ b/torchtune/utils/collate.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Tuple +from typing import Dict, List import torch @@ -21,7 +21,7 @@ def padded_collate( convert integer lists to tensors. Args: - batch (List[Dict[str, List[int]]]): A list of tuples containing input, label pairs. + batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs. padding_idx (int): Padding index for input ids. Defaults to 0. ignore_idx (int): Padding index for labels. Defaults to -100. @@ -69,67 +69,3 @@ def padded_collate( value=padding_idx, ) return {"tokens": input_ids.long(), "labels": labels.long()} - - -def padded_collate_dpo( - batch: List[Dict[str, List[int]]], - padding_idx: int = 0, - ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Pad a batch of sequences for Direct Preference Optimization (DPO). - - This function takes a batch of sequences, where each sequence is represented - as a dictionary with multiple key-value pairs. Each key corresponds to a different - sequence component, such as input_ids or labels. - - This function will throw an AssertionError if: - - the length of chosen_input_ids and rejected_input_ids differ. - - the length of chosen_labels and rejected_labels differ. 
- - Args: - batch (List[Dict[str, List[int]]]): A list of dictionaries, where each dictionary - represents a sequence with multiple components, 'chosen_input_ids', - 'chosen_labels', 'rejected_input_ids', and 'rejected_labels' are required. - padding_idx (int): Padding index for input ids. Defaults to 0. - ignore_idx (int): Padding index for labels. Defaults to -100. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing concatenated and padded - input ids and labels. - - Example: - >>> batch = [ - >>> {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5], - >>> 'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]}, - >>> {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15], - >>> 'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]}, - >>> ] - >>> padded_collate_dpo(batch) - >>> (tensor([[ 1, 2, 3], - >>> [11, 12, 0], - >>> [ 4, 5, 0], - >>> [13, 14, 15]]), - >>> tensor([[ 6, 7, 8], - >>> [16, 17, -100], - >>> [ 9, 10, -100], - >>> [18, 19, 20]])) - """ - chosen_input_ids = [torch.tensor(ex["chosen_input_ids"]) for ex in batch] - rejected_input_ids = [torch.tensor(ex["rejected_input_ids"]) for ex in batch] - chosen_labels = [torch.tensor(ex["chosen_labels"]) for ex in batch] - rejected_labels = [torch.tensor(ex["rejected_labels"]) for ex in batch] - - assert len(chosen_input_ids) == len(rejected_input_ids) - assert len(chosen_labels) == len(rejected_labels) - - to_pad_input_ids = chosen_input_ids + rejected_input_ids - to_pad_labels = chosen_labels + rejected_labels - - concatenated_input_ids = pad_sequence( - to_pad_input_ids, batch_first=True, padding_value=padding_idx - ) - concatenated_labels = pad_sequence( - to_pad_labels, batch_first=True, padding_value=ignore_idx - ) - - return concatenated_input_ids, concatenated_labels diff --git a/torchtune/utils/constants.py b/torchtune/utils/constants.py index 793ffa4824..30b4a7ae8c 100644 --- a/torchtune/utils/constants.py +++ b/torchtune/utils/constants.py @@ -22,3 +22,7 @@ # total number of epochs for training; resumed training runs for # (total_epochs - epochs_run) number of epochs TOTAL_EPOCHS_KEY = "total_epochs" +# number of steps completed thus far - for PPO +STEPS_KEY = "steps_run" +# rng state for ensuring correct training resuming in PPO +RNG_KEY = "rng_state" diff --git a/torchtune/utils/pooling.py b/torchtune/utils/pooling.py index a0d72c51c0..ffbf5fad89 100644 --- a/torchtune/utils/pooling.py +++ b/torchtune/utils/pooling.py @@ -4,38 +4,39 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch -from torch import Tensor -def pool_sequence_logits( - tokens: Tensor, logits: Tensor, padding_token_idx: int -) -> Tensor: - """Pool sequence logits by selecting the predicted logits for the last non-padding token - for each sequence in the batch. +def get_unmasked_sequence_lengths(mask: torch.Tensor) -> torch.Tensor: + """ + Returns the sequence lengths for each batch element, excluding masked tokens. Args: - tokens (Tensor): input tensor with shape [b x s] - logits (Tensor): predicted logits for input tokens with shape [b x s x n] - padding_token_idx (int): Padding token id used in the tokenizer. 
+        mask (torch.Tensor): Boolean mask with shape [b x s], where True indicates a value to be masked out
+            - this is usually a mask for padding tokens, where True indicates a padding token
+
     Returns:
-        Tensor: Pooled logits with shape [b x n]
+        torch.Tensor: Indexes of the last unmasked (non-padding) token for each sequence, with shape [b]
+
     Notation used for tensor shapes:
         - b: batch size
         - s: sequence length
-        - n: number of classes
-    """
-    batch_size = tokens.shape[0]
-    # inspired by the HF implementation:
-    # https://github.com/huggingface/transformers/blob/928331381ef6ce0622c0b1ac704299046b3afa21/src/transformers/models/mistral/modeling_mistral.py#L1339
-
-    # calculate per-batch-element sequence lengths by finding EOS padding tokens
-    padding_mask = tokens == padding_token_idx
-    if padding_mask.any():
+
+    Example:
+        >>> input_ids = torch.tensor([
+        >>>        [2, 4, 0, 0],
+        >>>        [2, 4, 6, 0],
+        >>>        [2, 4, 6, 9]
+        >>>    ])
+        >>> get_unmasked_sequence_lengths(input_ids == 0)
+        >>> tensor([1, 2, 3])
+    """
+    # calculate per-batch-element sequence lengths by finding last valid tokens
+    if mask.any():
         sequence_lengths = (
-            padding_mask.logical_not().sum(-1).to(logits.device).sub(1).clip(0)
+            (~mask).sum(-1).sub(1).clip(0).to(mask.device, dtype=torch.long)
        )
     else:
-        sequence_lengths = -1
+        sequence_lengths = torch.full(
+            (mask.shape[0],), mask.shape[1] - 1, dtype=torch.long, device=mask.device
+        )
 
-    # grab logits for the last non-padding token for each sequence in the batch
-    return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+    return sequence_lengths
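Taken together, the utilities introduced in this diff compose into the reward-and-advantage stage of a PPO step roughly as follows. This is a minimal illustrative sketch rather than the recipe's actual wiring: the import paths assume the re-exports added elsewhere in this diff (``torchtune.modules.rlhf`` and ``torchtune.utils``), and the tensors are dummy stand-ins for real model outputs.

import torch

from torchtune.modules import rlhf
from torchtune.utils import get_unmasked_sequence_lengths

batch_size, response_len = 2, 4
stop_token_ids = torch.tensor([2])  # hypothetical EOS id, for illustration only

# toy generated responses (prompt/context tokens already stripped)
responses = torch.tensor([[5, 6, 2, 7], [5, 6, 7, 8]])

# truncate after the first stop token; True in padding_masks marks truncated positions
padding_masks, responses = rlhf.truncate_sequence_at_first_stop_token(
    responses, stop_token_ids, fill_value=0
)
# index of the last valid (non-padded) token per sequence, e.g. tensor([2, 3]) here
valid_score_idxs = get_unmasked_sequence_lengths(padding_masks)

# dummy per-token logprobs / values and per-sequence reward-model scores
logprobs = -torch.rand(batch_size, response_len)
ref_logprobs = -torch.rand(batch_size, response_len)
values = torch.rand(batch_size, response_len)
scores = torch.tensor([1.0, 0.5])

# per-token reward is -kl_coeff * (logprobs - ref_logprobs), with the reward-model
# score added at the last valid token of each response
total_reward, kl, kl_reward = rlhf.get_rewards_ppo(
    scores, logprobs, ref_logprobs, kl_coeff=0.1, valid_score_idxs=valid_score_idxs
)

# GAE over the response tokens; masks select the non-padded positions for whitening
advantages, returns = rlhf.estimate_advantages(
    values, total_reward, gamma=1.0, lmbda=0.95, masks=~padding_masks
)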