⚠️ Add warning guidelines and update codebase to follow best practices (#2350)

* Add guidelines for working with warnings in the codebase

* Remove unnecessary warnings and improve code initialization

* Fix warnings and improve accuracy calculation

* Add rich library dependency for text formatting

* Update LoRA weight loading warning message

* Fix logging and import issues in AlignPropConfig

* Fix warnings and improve code readability

* Remove unused import statements

* Refactor CPOTrainer class in cpo_trainer.py

* Remove unnecessary warnings and raise ValueError for missing model

* Fix warnings and improve code consistency

* Update CONTRIBUTING.md to clarify the purpose of warnings

* Fix string formatting in DataCollatorForCompletionOnlyLM class

* Update SimPO loss parameters in CPOTrainer

* Fix warnings and remove unnecessary code in ConstantLengthDataset class

* Clarify warning guidelines

* Rewrite the entire section

* Fix capitalization in CONTRIBUTING.md

* Fix formatting in CONTRIBUTING.md
qgallouedec authored Nov 29, 2024
1 parent 8d9cfaa commit d6a8f2c
Showing 20 changed files with 161 additions and 235 deletions.
53 changes: 53 additions & 0 deletions CONTRIBUTING.md
@@ -283,3 +283,56 @@ The deprecation and removal schedule is based on each feature's usage and impact
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
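As a concrete illustration, a deprecation warning that fits this schedule could look like the sketch below; the argument names and the v0.14 target are placeholders rather than a real TRL API.

```python
import warnings

def my_function(new_arg=None, old_arg=None):
    if old_arg is not None:
        # Name the replacement and the planned removal release so users can act on the warning.
        warnings.warn(
            "`old_arg` is deprecated and will be removed in v0.14. Use `new_arg` instead.",
            FutureWarning,
        )
        new_arg = old_arg
    return new_arg
```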

### Working with warnings

Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.

#### Definitions

- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.

#### Choosing the right message

- **Correct → No warning**:
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.

- **Correct but deserves attention → No warning, possibly a log message**:
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:

```python
logger.info("This is an informational message about a rare but correct operation.")
```

- **Correct but very likely a mistake → Warning with option to disable**:
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:

```python
def my_function(foo, bar, _warn=True):
    if foo == bar:
        if _warn:
            warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
    # Do something
```

- **Supported but not correct → Warning**:
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:

```python
def my_function(foo, bar):
    if foo and bar:
        warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
    # Do something
```

- **Not supported → Exception**:
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:

```python
def my_function(foo, bar):
    if foo and bar:
        raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
```

By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
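One practical reason to pass an explicit category such as `UserWarning` or `FutureWarning` (as several of the diffs below now do) is that users can then silence or escalate a whole class of warnings without patching the code. A minimal sketch using only the standard library:

```python
import warnings

# Hide deprecation notices from a noisy dependency while keeping other warnings visible.
warnings.filterwarnings("ignore", category=FutureWarning)

# In a test suite, escalate UserWarning to an error so questionable usage fails fast.
warnings.filterwarnings("error", category=UserWarning)
```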
2 changes: 1 addition & 1 deletion docs/source/cpo_trainer.mdx
@@ -75,7 +75,7 @@ While training and evaluating we record the following reward metrics:

### Simple Preference Optimization (SimPO)

The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the [`CPOConfig`].
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
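For reference, a minimal sketch of that configuration; the `output_dir` value is illustrative and the remaining arguments keep their defaults.

```python
from trl import CPOConfig

# Pure SimPO: reward-margin loss with length normalization and no BC regularization.
training_args = CPOConfig(output_dir="cpo-simpo-output", loss_type="simpo", cpo_alpha=0.0)
```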

### CPO-SimPO

3 changes: 2 additions & 1 deletion examples/scripts/reward_modeling.py
@@ -99,7 +99,8 @@
if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
warnings.warn(
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
UserWarning,
)

##############
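For context, a hedged sketch of a PEFT configuration that satisfies the check in the hunk above; the rank and alpha values are arbitrary placeholders.

```python
from peft import LoraConfig, TaskType

# Reward models are sequence classifiers, so the LoRA adapter should use the SEQ_CLS task type.
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16, lora_alpha=32)
```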
3 changes: 2 additions & 1 deletion trl/core.py
@@ -296,7 +296,8 @@ def randn_tensor(
warnings.warn(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slighly speed up this function by passing a generator that was created on the {device} device."
f" slighly speed up this function by passing a generator that was created on the {device} device.",
UserWarning,
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
19 changes: 12 additions & 7 deletions trl/environment/base_environment.py
@@ -13,7 +13,6 @@
# limitations under the License.

import re
import warnings
from typing import Optional

import torch
@@ -145,8 +144,10 @@ def show_text(self, show_legend=False):
Print the text history.
"""
if not is_rich_available():
warnings.warn("install rich to display text")
return
raise ImportError(
"The `rich` library is required to display text with formatting. "
"Install it using `pip install rich`."
)

text = Text(self.text)
text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
@@ -167,8 +168,10 @@ def show_tokens(self, tokenizer, show_legend=False):
Print the history tokens.
"""
if not is_rich_available():
warnings.warn("install rich to display tokens")
return
raise ImportError(
"The `rich` library is required to display tokens with formatting. "
"Install it using `pip install rich`."
)

text = Text()
prompt_end = self.token_spans[0][1]
@@ -192,8 +195,10 @@ def show_colour_legend(self):
Print the colour legend.
"""
if not is_rich_available():
warnings.warn("install rich to display colour legend")
return
raise ImportError(
"The `rich` library is required to display colour legends with formatting. "
"Install it using `pip install rich`."
)
text = Text("\n\n(Colour Legend: ")
text.append("Prompt", style=self.prompt_color)
text.append("|")
5 changes: 3 additions & 2 deletions trl/models/modeling_sd_base.py
@@ -808,8 +808,9 @@ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str
except OSError:
if use_lora:
warnings.warn(
"If you are aware that the pretrained model has no lora weights to it, ignore this message. "
"Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder."
"Trying to load LoRA weights but no LoRA weights found. Set `use_lora=False` or check that "
"`pytorch_lora_weights.safetensors` exists in the model folder.",
UserWarning,
)

self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
11 changes: 1 addition & 10 deletions trl/trainer/alignprop_config.py
@@ -14,11 +14,10 @@

import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Any, Literal, Optional

from transformers import is_bitsandbytes_available, is_torchvision_available
from transformers import is_bitsandbytes_available

from ..core import flatten_dict

@@ -139,14 +138,6 @@ def to_dict(self):
return flatten_dict(output_dict)

def __post_init__(self):
if self.log_with not in ["wandb", "tensorboard"]:
warnings.warn(
"Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
)

if self.log_with == "wandb" and not is_torchvision_available():
warnings.warn("Wandb image logging requires torchvision to be installed")

if self.train_use_8bit_adam and not is_bitsandbytes_available():
raise ImportError(
"You need to install bitsandbytes to use 8bit Adam. "
34 changes: 7 additions & 27 deletions trl/trainer/bco_trainer.py
@@ -394,17 +394,9 @@ def __init__(
ref_model_init_kwargs["torch_dtype"] = torch_dtype

if isinstance(model, str):
warnings.warn(
"You passed a model_id to the BCOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

if isinstance(ref_model, str):
warnings.warn(
"You passed a ref model_id to the BCOTrainer. This will automatically create an "
"`AutoModelForCausalLM`"
)
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
@@ -573,8 +565,11 @@ def make_inputs_require_grad(module, input, output):
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
warnings.warn(
"You set `output_router_logits` to True in the model config, but `router_aux_loss_coef` is set to 0.0,"
" meaning the auxiliary loss will not be used."
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
"loss.",
UserWarning,
)

# Underlying Distribution Matching argument
@@ -714,7 +709,6 @@ def make_inputs_require_grad(module, input, output):
self.running = RunningMoments(accelerator=self.accelerator)

if self.embedding_func is None:
warnings.warn("You did not pass `embedding_func` underlying distribution matching feature is deactivated.")
return

chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
@@ -884,16 +878,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
return
# when loading optimizer and scheduler from checkpoint, also load the running delta object.
running_file = os.path.join(checkpoint, RUNNING_NAME)
if not os.path.isfile(running_file):
warnings.warn(f"Missing file {running_file}. Will use a new running delta value for BCO loss calculation")
else:
if os.path.isfile(running_file):
self.running = RunningMoments.load_from_json(self.accelerator, running_file)

if self.match_underlying_distribution:
clf_file = os.path.join(checkpoint, CLF_NAME)
if not os.path.isfile(running_file):
warnings.warn(f"Missing file {clf_file}. Will use a new UDM classifier for BCO loss calculation")
else:
if os.path.isfile(running_file):
self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))

@contextmanager
@@ -1278,11 +1268,6 @@ def compute_loss(
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

with compute_loss_context_manager:
@@ -1359,11 +1344,6 @@ def prediction_step(
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
):
if not self.use_dpo_data_collator:
warnings.warn(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
33 changes: 8 additions & 25 deletions trl/trainer/cpo_trainer.py
@@ -144,10 +144,6 @@ def __init__(
model_init_kwargs["torch_dtype"] = torch_dtype

if isinstance(model, str):
warnings.warn(
"You passed a model_id to the CPOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
Expand Down Expand Up @@ -290,7 +286,9 @@ def make_inputs_require_grad(module, input, output):

if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
"`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
UserWarning,
)
if args.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
@@ -303,19 +301,15 @@ def make_inputs_require_grad(module, input, output):
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
warnings.warn(
"You set `output_router_logits` to True in the model config, but `router_aux_loss_coef` is set to 0.0,"
" meaning the auxiliary loss will not be used."
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
"loss.",
UserWarning,
)

if args.loss_type == "simpo":
self.simpo_gamma = args.simpo_gamma
if self.cpo_alpha > 0:
warnings.warn(
"You are using CPO-SimPO method because you set a non-zero cpo_alpha. "
"This will result in the CPO-SimPO method "
"(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). "
"If you want to use a pure SimPO method, please set cpo_alpha to 0."
)

self._stored_metrics = defaultdict(lambda: defaultdict(list))

@@ -845,12 +839,6 @@ def compute_loss(
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)

compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

with compute_loss_context_manager:
@@ -891,11 +879,6 @@ def prediction_step(
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
):
if not self.use_dpo_data_collator:
warnings.warn(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
11 changes: 1 addition & 10 deletions trl/trainer/ddpo_config.py
@@ -14,11 +14,10 @@

import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import is_bitsandbytes_available, is_torchvision_available
from transformers import is_bitsandbytes_available

from ..core import flatten_dict

@@ -167,14 +166,6 @@ def to_dict(self):
return flatten_dict(output_dict)

def __post_init__(self):
if self.log_with not in ["wandb", "tensorboard"]:
warnings.warn(
"Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
)

if self.log_with == "wandb" and not is_torchvision_available():
warnings.warn("Wandb image logging requires torchvision to be installed")

if self.train_use_8bit_adam and not is_bitsandbytes_available():
raise ImportError(
"You need to install bitsandbytes to use 8bit Adam. "
3 changes: 2 additions & 1 deletion trl/trainer/dpo_config.py
@@ -192,7 +192,8 @@ class DPOConfig(TrainingArguments):
def __post_init__(self):
if self.max_target_length is not None:
warnings.warn(
"The `max_target_length` argument is deprecated in favor of `max_completion_length` and will be removed in a future version.",
"The `max_target_length` argument is deprecated in favor of `max_completion_length` and will be "
"removed in v0.14.",
FutureWarning,
)
if self.max_completion_length is None:
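On the caller's side the migration is a simple rename; a minimal sketch, with `output_dir` and the length value chosen only for illustration.

```python
from trl import DPOConfig

# Deprecated spelling: DPOConfig(..., max_target_length=256)
training_args = DPOConfig(output_dir="dpo-output", max_completion_length=256)
```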