
Improves formatting of docstring + newlines #2006

Merged · 13 commits · Sep 9, 2024
1 change: 0 additions & 1 deletion trl/commands/__init__.py
@@ -23,7 +23,6 @@
"cli_utils": ["SFTScriptArguments", "init_zero_verbose", "DPOScriptArguments", "TrlParser", "YamlConfigParser"],
}


if TYPE_CHECKING:
from .cli_utils import SFTScriptArguments, init_zero_verbose, DPOScriptArguments, TrlParser, YamlConfigParser
else:
2 changes: 1 addition & 1 deletion trl/commands/cli_utils.py
@@ -42,7 +42,7 @@ def parse_and_set_env(self, config_path):
return config

def to_string(self, config):
final_string = """"""
final_string = ""
for key, value in config.items():
if isinstance(value, (dict, list)):
if len(value) != 0:
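Worth noting why this hunk is purely cosmetic: `""""""` is just an empty triple-quoted string literal, so both spellings produce the same value. A quick check in plain Python (nothing TRL-specific):

```python
# Both literals are empty strings; only the spelling differs.
assert """""" == ""
final_string = ""  # the simpler spelling kept by this PR
```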
1 change: 1 addition & 0 deletions trl/environment/__init__.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule


_import_structure = {
"base_environment": ["TextEnvironment", "TextHistory"],
}
6 changes: 3 additions & 3 deletions trl/environment/base_environment.py
@@ -64,7 +64,7 @@ def __init__(self, text, tokens, system=True):
"""
Initialize TextHistory.

args:
Args:
text (`str`): The text of the first segment.
tokens (`torch.LongTensor`): The tokens of the first segment.
system (`bool`, *optional*): Whether the first segment is a system or user segment.
@@ -90,7 +90,7 @@ def append_segment(self, text, tokens, system=True):
"""
Append a new segment to the history.

args:
Args:
text (`str`): The text of the new segment.
tokens (`torch.LongTensor`): The tokens of the new segment.
system (`bool`, *optional*): Whether the new segment is a system or user segment.
@@ -422,7 +422,7 @@ def _generate_batched(
"""
Generate responses for a list of query tensors.

args:
Args:
query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
batch_size (int): The batch size to use for generation.
pad_to_multiple_of (int): The padding length to use for generation.
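For reviewers unfamiliar with this API, a minimal usage sketch of the `TextHistory` methods whose docstrings are touched above; the token ids are placeholders, not real tokenizer output:

```python
# Minimal sketch of the TextHistory API documented above; token ids are made up.
import torch
from trl.environment import TextHistory

prompt = "What is 13 - 3?"
prompt_tokens = torch.tensor([101, 2054, 2003])   # placeholder ids for the first (system/user) segment
history = TextHistory(prompt, prompt_tokens, system=True)

reply = "13 - 3 = 10"
reply_tokens = torch.tensor([2410, 1011, 1017])   # placeholder ids for the model segment
history.append_segment(reply, reply_tokens, system=False)
```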
9 changes: 4 additions & 5 deletions trl/models/modeling_base.py
@@ -64,11 +64,11 @@ class PreTrainedModelWrapper(nn.Module):
(`~transformers.PreTrainedModel`) class.

Attributes:
pretrained_model: (`transformers.PreTrainedModel`)
pretrained_model (`transformers.PreTrainedModel`):
The model to be wrapped.
parent_class: (`transformers.PreTrainedModel`)
parent_class (`transformers.PreTrainedModel`):
The parent class of the model to be wrapped.
supported_args: (`list`)
supported_args (`list`):
The list of arguments that are supported by the wrapper class.
"""

@@ -118,7 +118,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
`transformers.PreTrainedModel` class are passed along this method and filtered
out from the `kwargs` argument.


Args:
pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
The path to the pretrained model or its name.
@@ -617,7 +616,7 @@ def create_reference_model(
pattern (`str`, *optional*): The shared layers are selected with a string pattern
(e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.

Returns
Returns:
`PreTrainedModelWrapper`
"""
if is_deepspeed_zero3_enabled():
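Since both the wrapper's attribute docstring and `create_reference_model` are touched in this file, a short hedged sketch of how they fit together (the model id is illustrative):

```python
# Hedged sketch: wrap a causal LM with a value head, then build a frozen reference copy.
from trl import AutoModelForCausalLMWithValueHead, create_reference_model

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

ref_model = create_reference_model(model)                               # fully independent frozen copy
ref_model_shared = create_reference_model(model, num_shared_layers=6)   # share the first 6 layers
```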
8 changes: 7 additions & 1 deletion trl/models/modeling_sd_base.py
@@ -556,7 +556,13 @@ def pipeline_step_with_grad(
guidance_rescale: float = 0.0,
):
r"""
Function to get RGB image with gradients attached to the model weights. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor): The height in pixels of the generated image.
Function to get RGB image with gradients attached to the model weights.

Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
height (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
1 change: 0 additions & 1 deletion trl/models/modeling_value_head.py
@@ -84,7 +84,6 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
- **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default
strategy.
- **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.

"""

transformers_parent_class = AutoModelForCausalLM
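The `v_head_init_strategy` options described in this docstring are passed as keyword arguments at load time; a hedged sketch (model id illustrative, kwarg name as used elsewhere in TRL rather than taken from this diff):

```python
# Hedged sketch: choosing the ValueHead weight-initialization strategy documented above.
from trl import AutoModelForCausalLMWithValueHead

# None (default): random initialization of the ValueHead weights
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

# "normal": initialize the ValueHead weights from a normal distribution
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2", v_head_init_strategy="normal")
```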
2 changes: 0 additions & 2 deletions trl/models/sd_utils.py
@@ -59,13 +59,11 @@ class StateDictType(enum.Enum):
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}


DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}


KEYS_TO_ALWAYS_REPLACE = {
".processor.": ".",
}
14 changes: 7 additions & 7 deletions trl/models/utils.py
@@ -14,7 +14,6 @@
AutoModelForSeq2SeqLMWithValueHead,
)


if is_deepspeed_available():
import deepspeed

@@ -72,13 +71,14 @@ def setup_chat_format(
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.

Args:
model (`~transformers.PreTrainedModel`): The model to be modified.
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.
model (`~transformers.PreTrainedModel`): The model to be modified.
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.

Returns:
model (`~transformers.PreTrainedModel`): The modified model.
tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
model (`~transformers.PreTrainedModel`): The modified model.
tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
"""
# check if format available and retrieve
if format not in FORMAT_MAPPING:
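For context, a minimal sketch of the `setup_chat_format` call documented above (model and tokenizer ids are illustrative):

```python
# Minimal sketch: add the ChatML special tokens and pad the embedding matrix to a multiple of 64.
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)
```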
2 changes: 1 addition & 1 deletion trl/trainer/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable


_import_structure = {
"callbacks": ["RichProgressCallback", "SyncRefModelCallback"],
"utils": [
@@ -74,7 +75,6 @@
else:
_import_structure["ddpo_trainer"] = ["DDPOTrainer"]


if TYPE_CHECKING:
# isort: off
from .callbacks import RichProgressCallback, SyncRefModelCallback
19 changes: 11 additions & 8 deletions trl/trainer/alignprop_trainer.py
@@ -29,7 +29,6 @@

logger = get_logger(__name__)


MODEL_CARD_TEMPLATE = """---
license: apache-2.0
library_name: transformers
@@ -56,12 +55,16 @@ class AlignPropTrainer(BaseTrainer):
As of now only Stable Diffusion based pipelines are supported

Attributes:
**config** (`AlignPropConfig`) -- Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more
details.
**reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
**prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
config (`AlignPropConfig`):
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
reward_function (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]):
Reward function to be used
prompt_function (Callable[[], Tuple[str, Any]]):
Function to generate prompts to guide model
sd_pipeline (`DDPOStableDiffusionPipeline`):
Stable Diffusion pipeline to be used for training.
image_samples_hook (Optional[Callable[[Any, Any, Any], Any]]):
Hook to be called to log images
"""

_tag_names = ["trl", "alignprop"]
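Because the attribute list above effectively doubles as the constructor contract, a hedged wiring sketch may help reviewers; the pipeline class, model id, and toy reward/prompt functions below are illustrative, not taken from this diff:

```python
# Hedged sketch of wiring an AlignPropTrainer together; reward/prompt functions are toy placeholders.
from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline

def prompt_fn():
    return "a photo of a cat", {}              # (prompt, prompt_metadata)

def reward_fn(images, prompts, prompt_metadata):
    return images.mean(dim=(1, 2, 3))          # toy differentiable scalar reward per image

config = AlignPropConfig()
pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
trainer = AlignPropTrainer(config, reward_fn, prompt_fn, pipeline)
# training then proceeds epoch by epoch through trainer.step(epoch, global_step), documented below
```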
@@ -215,7 +218,6 @@ def step(self, epoch: int, global_step: int):

Returns:
global_step (int): The updated global step.

"""
info = defaultdict(list)

@@ -281,6 +283,7 @@ def calculate_loss(self, rewards):
Args:
rewards (torch.Tensor):
Differentiable reward scalars for each generated image, shape: [batch_size]

Returns:
loss (torch.Tensor)
(all of these are of shape (1,))
7 changes: 3 additions & 4 deletions trl/trainer/bco_trainer.py
@@ -59,7 +59,6 @@
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
import wandb

@@ -148,11 +147,11 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **

At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + completion responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the completion.
we truncate the prompt; if we're still too long, we truncate the completion.

We also create the labels for the completion responses, which are of length equal to
the sum of the length of the prompt and the completion response, with
label_pad_token_id for the prompt tokens.
the sum of the length of the prompt and the completion response, with
label_pad_token_id for the prompt tokens.
"""
prompt = example["prompt"]
completion = example["completion"]
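The truncation rule described in this docstring (and repeated verbatim in the CPO/KTO/ORPO tokenizers below) is easy to misread, so here is an illustrative sketch, not the trainer's actual code, of the intended behaviour:

```python
# Illustrative only: shorten the prompt first, then the completion, so the pair fits max_length.
def truncate_pair(prompt_ids, completion_ids, max_length, max_prompt_length):
    if len(prompt_ids) + len(completion_ids) > max_length:
        prompt_ids = prompt_ids[-max_prompt_length:]                      # trim the prompt first
    if len(prompt_ids) + len(completion_ids) > max_length:
        completion_ids = completion_ids[: max_length - len(prompt_ids)]   # then trim the completion
    return prompt_ids, completion_ids

# Labels cover prompt + completion, with prompt positions masked by label_pad_token_id:
label_pad_token_id = -100
prompt_ids, completion_ids = [5, 6, 7], [8, 9]
labels = [label_pad_token_id] * len(prompt_ids) + completion_ids   # [-100, -100, -100, 8, 9]
```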
6 changes: 3 additions & 3 deletions trl/trainer/cpo_trainer.py
@@ -378,11 +378,11 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module

At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
"""
batch = {}
prompt = feature["prompt"]
1 change: 0 additions & 1 deletion trl/trainer/ddpo_trainer.py
@@ -32,7 +32,6 @@

logger = get_logger(__name__)


MODEL_CARD_TEMPLATE = """---
license: apache-2.0
library_name: transformers
39 changes: 25 additions & 14 deletions trl/trainer/iterative_sft_trainer.py
@@ -42,22 +42,31 @@ class IterativeSFTTrainer(Trainer):
"""
The IterativeSFTTrainer can be used to finetune models with methods that require some steps between optimization.

Attributes:
**model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
Args:
model (`PreTrainedModel`):
Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
Check the documentation of `PreTrainedModel` for more details.
**args** (`transformers.TrainingArguments`): -- The arguments to use for training.
**tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
data. Check the documentation of `transformers.PreTrainedTokenizer` and
args (`transformers.TrainingArguments`):
The arguments to use for training.
tokenizer (`PreTrainedTokenizerBase`):
Tokenizer to be used for encoding the data. Check the documentation of `transformers.PreTrainedTokenizer` and
`transformers.PreTrainedTokenizerFast` for more details.
**optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training.
**data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and
passed along the dataloader.
**eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation.
**max_length** (`int`, defaults to `None`): -- The maximum length of the input.
**truncation_mode** (`str`, defaults to `keep_end`): -- The truncation mode to use, either `keep_end` or `keep_start`.
**preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): -- The function to use to preprocess the logits before computing the metrics.
**compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*): -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values.
**optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*):
Data collator to be used for training and passed along the dataloader.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
max_length (`int`, defaults to `None`):
The maximum length of the input.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values.
optimize_device_cache (`bool`, *optional*, defaults to `False`):
Optimize CUDA cache for slightly more memory-efficient training.
"""

_tag_names = ["trl", "iterative-sft"]
@@ -199,6 +208,7 @@ def _step_safety_checker(
List of string containing the text input.
texts_labels (List[`str`]):
List of string containing the text labels.

Returns:
`tuple`: The input data.
"""
@@ -252,6 +262,7 @@ def step(
List of strings containing the text input (if not provided, input_ids will directly be used)
texts_labels (List[`str`], *optional*):
List of strings containing the text labels (if set to None, will default to text)

Returns:
`dict[str, Any]`: A summary of the training statistics
"""
6 changes: 3 additions & 3 deletions trl/trainer/kto_trainer.py
@@ -138,11 +138,11 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **

At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + completion responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the completion.
we truncate the prompt; if we're still too long, we truncate the completion.

We also create the labels for the completion responses, which are of length equal to
the sum of the length of the prompt and the completion response, with
label_pad_token_id for the prompt tokens.
the sum of the length of the prompt and the completion response, with
label_pad_token_id for the prompt tokens.
"""
prompt = example["prompt"]
completion = example["completion"]
6 changes: 3 additions & 3 deletions trl/trainer/orpo_trainer.py
@@ -395,11 +395,11 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module

At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
"""
batch = {}
prompt = feature["prompt"]
2 changes: 2 additions & 0 deletions trl/trainer/ppo_trainer.py
@@ -616,6 +616,7 @@ def _step_safety_checker(
List of tensors containing the scores.
masks (List[`torch.LongTensor`], *optional*):
list of optional tensors containing the masks of shape (`response_length`)

Returns:
`tuple`: The input processed data.
"""
@@ -988,6 +989,7 @@ def batched_forward_pass(
List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
return_logits (`bool`, *optional*, defaults to `False`):
Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.

Returns:
(tuple):
- all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,