Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Improve docs #91

Merged
merged 15 commits into from
Jan 18, 2023
6 changes: 6 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
- local: installation
title: Installation
title: Get started
- sections:
- local: models
title: Model Classes
- local: trainer
title: Trainer Classes
title: API
- sections:
- local: sentiment_tuning
title: Sentiment Tuning
Expand Down
16 changes: 16 additions & 0 deletions docs/source/models.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Models

TRL supports various model architectures including most used text generative models.
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved

## PreTrainedModelWrapper

[[autodoc]] PreTrainedModelWrapper

## AutoModelForCausalLMWithValueHead


[[autodoc]] AutoModelForCausalLMWithValueHead
- __init__
- forward
- generate
- _init_weights
12 changes: 12 additions & 0 deletions docs/source/trainer.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Trainer

At TRL we plan to release several RLHF algorithms, we started our journey with PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since adding new algorithms is not on the roadmap at the moment maybe let's just focus on PPO :)

We could also add a sentence or two about the classes. E.g. that they are inspired/influence by the transformers.Trainer and are adapted to RL.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Adapted the text in da456cf


## PPOConfig

[[autodoc]] PPOConfig

## PPOTrainer

[[autodoc]] PPOTrainer

2 changes: 1 addition & 1 deletion trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

__version__ = "0.1.1"

from .models import AutoModelForCausalLMWithValueHead, create_reference_model
from .models import AutoModelForCausalLMWithValueHead, PreTrainedModelWrapper, create_reference_model
from .trainer import PPOConfig, PPOTrainer
71 changes: 47 additions & 24 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ class PreTrainedModelWrapper(nn.Module):
(`~transformers.PreTrained`) class in order to keep some attributes and methods of the
(`~transformers.PreTrainedModel`) class.

Attributes
----------
pretrained_model: (`transformers.PreTrainedModel`)
The model to be wrapped.
parent_class: (`transformers.PreTrainedModel`)
The parent class of the model to be wrapped.
supported_args: (`list`)
The list of arguments that are supported by the wrapper class.
Attributes:
pretrained_model: (`transformers.PreTrainedModel`)
The model to be wrapped.
parent_class: (`transformers.PreTrainedModel`)
The parent class of the model to be wrapped.
supported_args: (`list`)
The list of arguments that are supported by the wrapper class.
"""
transformers_parent_class = None
supported_args = None
Expand All @@ -45,20 +44,24 @@ def __init__(self, pretrained_model=None, **kwargs):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Instantiates a new model from a pretrained model.

Parameters
----------
pretrained_model_name_or_path: (`str` or `transformers.PreTrainedModel`)
The path to the pretrained model or its name.
*model_args:
Additional positional arguments passed along to the underlying model's
`from_pretrained` method.
**kwargs:
Additional keyword arguments passed along to the underlying model's
`from_pretrained` method. We also pre-process the kwargs to extract
the arguments that are specific to the `transformers.PreTrainedModel`
class and the arguments that are specific to trl models.
Instantiates a new model from a pretrained model from `transformers`. The
pretrained model is loaded using the `from_pretrained` method of the
`transformers.PreTrainedModel` class. The arguments that are specific to the
`transformers.PreTrainedModel` class are passed along this method and filtered
out from the `kwargs` argument.


Args:
pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
The path to the pretrained model or its name.
*model_args (`list`, *optional*)):
Additional positional arguments passed along to the underlying model's
`from_pretrained` method.
**kwargs (`dict`, *optional*):
Additional keyword arguments passed along to the underlying model's
`from_pretrained` method. We also pre-process the kwargs to extract
the arguments that are specific to the `transformers.PreTrainedModel`
class and the arguments that are specific to trl models.
"""
if kwargs is not None:
trl_model_args, pretrained_kwargs = cls._split_kwargs(kwargs)
Expand Down Expand Up @@ -104,13 +107,33 @@ def _split_kwargs(cls, kwargs):

def push_to_hub(self, *args, **kwargs):
r"""
Push the pretrained model to the hub.
Push the pretrained model to the hub. This method is a wrapper around
`transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation
of `transformers.PreTrainedModel.push_to_hub` for more information.

Args:
*args (`list`, *optional*):
Positional arguments passed along to the underlying model's
`push_to_hub` method.
**kwargs (`dict`, *optional*):
Keyword arguments passed along to the underlying model's
`push_to_hub` method.
"""
return self.pretrained_model.push_to_hub(*args, **kwargs)

def save_pretrained(self, *args, **kwargs):
r"""
Save the pretrained model to a directory.
Save the pretrained model to a directory. This method is a wrapper around
`transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation
of `transformers.PreTrainedModel.save_pretrained` for more information.

Args:
*args (`list`, *optional*):
Positional arguments passed along to the underlying model's
`save_pretrained` method.
**kwargs (`dict`, *optional*):
Keyword arguments passed along to the underlying model's
`save_pretrained` method.
"""
return self.pretrained_model.save_pretrained(*args, **kwargs)

Expand Down
77 changes: 66 additions & 11 deletions trl/models/modeling_value_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,27 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
An autoregressive model with a value head in addition to the language model head.
This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
`transformers.PreTrainedModel` class. The wrapper class supports classic functions
such as `from_pretrained` and `push_to_hub` and also provides some additional
functionalities such as `generate`.

Args:
pretrained_model (`transformers.PreTrainedModel`):
The model to wrap. It should be a causal language model such as GPT2.
or any model mapped inside the `AutoModelForCausalLM` class.
kwargs:
Additional keyword arguments passed along to the `ValueHead` class.
such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
model, simply manipulate the `pretrained_model` attribute of this class.

Class attributes:
- **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
should be set to `transformers.AutoModelForCausalLM` for this class.
- **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models
in the future
- **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
by the `ValueHead` class. Currently the supported args are:
- **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
`ValueHead` class.
- **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
`ValueHead` if a specific initialization strategy is selected.
- **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
`ValueHead`. Currently supported strategies are:
- **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default
strategy.
- **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.

"""
transformers_parent_class = AutoModelForCausalLM
lm_head_namings = ["lm_head", "embed_out"]
Expand All @@ -78,6 +90,16 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
)

def __init__(self, pretrained_model, **kwargs):
r"""
Initializes the model.

Args:
pretrained_model (`transformers.PreTrainedModel`):
The model to wrap. It should be a causal language model such as GPT2.
or any model mapped inside the `AutoModelForCausalLM` class.
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the `ValueHead` class.
"""
super().__init__(pretrained_model)
v_head_kwargs, _ = self._split_kwargs(kwargs)

Expand All @@ -90,7 +112,16 @@ def __init__(self, pretrained_model, **kwargs):

def _init_weights(self, **kwargs):
r"""
We initialize the weights of the value head.
Initializes the weights of the value head. The default initialization strategy is random.
Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument
when calling `.from_pretrained`. Supported strategies are:
- `normal`: initializes the weights with a normal distribution.

Args:
**kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the `ValueHead` class. These arguments
can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range`
argument.
"""
initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
# random init by default
Expand All @@ -109,6 +140,22 @@ def forward(
attention_mask=None,
**kwargs,
):
r"""
Applies a forward pass to the wrapped model and returns the logits of the value head.

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past_key_values` input) to speed up sequential decoding.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the wrapped model.
"""
base_model_output = self.pretrained_model(
input_ids=input_ids,
past_key_values=past_key_values,
Expand All @@ -127,6 +174,14 @@ def forward(

def generate(self, *args, **kwargs):
r"""
We call `generate` on the wrapped model.
A simple wrapper around the `generate` method of the wrapped model.
Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
method of the wrapped model for more information about the supported arguments.

Args:
*args (`list`, *optional*):
Positional arguments passed to the `generate` method of the wrapped model.
**kwargs (`dict`, *optional*):
Keyword arguments passed to the `generate` method of the wrapped model.
"""
return self.pretrained_model.generate(*args, **kwargs)
54 changes: 33 additions & 21 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,35 @@
whiten,
)
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig


class PPOTrainer(BaseTrainer):
"""
The PPOTrainer uses Proximal Policy Optimization to optimise language models.

Attributes:
- **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
- **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
Check the documentation of `PreTrainedModelWrapper` for more details.
- **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face transformer model with a casual language modelling head.
Check the documentation of `PreTrainedModelWrapper` for more details. If no reference model is provided, the
trainer will create a reference model with the same architecture as the model to be optimized with shared layers.
- **tokenizer** (`Union[PreTrainedTokenizer, PreTrainedTokenizerFast]`) -- Tokenizer to be used for encoding the data. Check the documentation of `transformers.PreTrainedTokenizer` and
`transformers.PreTrainedTokenizerFast` for more details.
- **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided,
the dataloader must be created outside the trainer users needs to design their own dataloader and make sure the batch
size that is used is the same as the one specified in the configuration object.
- **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is provided, the trainer will create an Adam optimizer with
the learning rate specified in the configuration object.
- **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and passed along the dataloader
- **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference model, if no reference model is passed. If no number is provided, all the layers
will be shared.
"""

def __init__(
self,
config,
config: PPOConfig,
model: PreTrainedModelWrapper,
ref_model: PreTrainedModelWrapper,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
Expand Down Expand Up @@ -167,8 +185,7 @@ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset],
Data collator function.

Returns:
`torch.utils.data.DataLoader`:
PyTorch dataloader
`torch.utils.data.DataLoader`: PyTorch dataloader
"""
if isinstance(dataset, Dataset):
dataset = self._remove_unused_columns(dataset)
Expand Down Expand Up @@ -210,7 +227,8 @@ def _remove_unused_columns(self, dataset: "Dataset"):

def generate(self, query_tensor: torch.Tensor, **generation_kwargs):
"""
Generate response given query.
Generate response given the query tensor. First unwrap the model from the accelerator and then
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
call the `generate` method of the model.

Args:
query_tensor (`torch.LongTensor`):
Expand All @@ -219,8 +237,7 @@ def generate(self, query_tensor: torch.Tensor, **generation_kwargs):
Keyword arguments for generation.

Returns:
response (`torch.LongTensor`):
A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
`torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
"""
response = self.accelerator.unwrap_model(self.model).generate(
query_tensor.unsqueeze(dim=0), **generation_kwargs
Expand Down Expand Up @@ -248,8 +265,7 @@ def _step_safety_checker(
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
Returns:
queries, responses, scores (List[`torch.LongTensor`], List[`torch.LongTensor`], List[`torch.FloatTensor`]):
The input processed data.
`tuple`: The input processed data.
"""
for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
if not isinstance(tensor_list, list):
Expand Down Expand Up @@ -282,7 +298,8 @@ def step(
scores: List[torch.FloatTensor],
):
"""
Run a PPO optimisation step.
Run a PPO optimisation step given the input data. The input data is first checked for validity
and then the forward pass is run.
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved

Args:
queries (List[`torch.LongTensor`]):
Expand All @@ -293,8 +310,7 @@ def step(
List of tensors containing the scores.

Returns:
train_stats (dict[str, Any]):
a summary of the training statistics
`dict[str, Any]`: A summary of the training statistics
"""

bs = self.config.batch_size
Expand Down Expand Up @@ -370,8 +386,7 @@ def gather_stats(self, stats):
a dictionary of stats to be gathered. The stats should contain torch tensors.

Returns:
stats (dict[str, Any]):
a dictionary of stats with the tensors gathered.
`dict[str, Any]`: A dictionary of stats with the tensors gathered.
"""
import torch.distributed as dist

Expand All @@ -396,13 +411,10 @@ def batched_forward_pass(self, queries: torch.Tensor, responses: torch.Tensor):
List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)

Returns:
all_logprobs (`torch.FloatTensor`):
List of tensors containing the logprobs, shape (`batch_size`, `response_length`)
all_ref_logprobs (`torch.FloatTensor`):
List of tensors containing the logprobs from the reference model, shape (`batch_size`, `response_length`)
all_values (`torch.FloatTensor`):
List of tensors containing the output from the value head, shape (`batch_size`, `response_length`)

(tuple):
- all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, shape (`batch_size`, `response_length`)
- all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses, shape (`batch_size`, `response_length`)
- all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
"""
bs = self.config.batch_size
fbs = self.config.forward_batch_size
Expand Down