[pre-commit] update pre-commit yaml #2002

Merged: 3 commits, Sep 2, 2024
.pre-commit-config.yaml: 3 additions, 1 deletion

@@ -1,10 +1,12 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.0
+    rev: v0.6.3
     hooks:
       - id: ruff
         types_or: [ python, pyi ]
         args: [ --fix ]
+      - id: ruff-format
+        types_or: [ python, pyi ]
 
 #  - repo: https://github.com/codespell-project/codespell
 #    rev: v2.1.0
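The newly added `ruff-format` hook is presumably what drives the mechanical changes in the rest of this PR: a blank line inserted after each module docstring, and long assignments rewrapped with a parenthesized right-hand side. The `==` to `is` type comparisons further down align with ruff's E721 rule. A minimal sketch of the docstring rule (an illustrative file, not taken from TRL):

```python
"""Illustrative module, not part of TRL.

Usage:
python illustrative_module.py
"""

# ruff-format keeps exactly one blank line between the module docstring above
# and the first import, which is the whole diff for the example scripts below.
import sys

print(sys.version.split()[0])
```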
examples/scripts/alignprop.py: 1 addition, 0 deletions

@@ -24,6 +24,7 @@
     --log_with="wandb"
 
 """
+
 from dataclasses import dataclass, field
 
 import numpy as np
examples/scripts/ddpo.py: 1 addition, 0 deletions

@@ -24,6 +24,7 @@
     --tracker_project_name="stable_diffusion_training" \
     --log_with="wandb"
 """
+
 import os
 from dataclasses import dataclass, field
 
examples/scripts/ppo.py: 1 addition, 0 deletions

@@ -15,6 +15,7 @@
 python examples/scripts/ppo.py \
     --log_with=wandb
 """
+
 from dataclasses import dataclass, field
 from typing import Optional
 
examples/scripts/reward_modeling.py: 1 addition, 0 deletions

@@ -28,6 +28,7 @@
     --eval_steps=500 \
     --max_length=512 \
 """
+
 import warnings
 
 import torch
scripts/stale.py: 1 addition, 0 deletions

@@ -15,6 +15,7 @@
 Script to close stale issue. Taken in part from the AllenNLP repository.
 https://github.com/allenai/allennlp.
 """
+
 import os
 from datetime import datetime as dt
 from datetime import timezone
setup.py: 2 additions, 1 deletion

@@ -1,4 +1,4 @@
-""" trl is an open library for RL with transformer models.
+"""trl is an open library for RL with transformer models.
 
 Note:
 
@@ -53,6 +53,7 @@
 8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
    Then push the change with a message 'set dev version'
 """
+
 import os
 
 from setuptools import find_packages, setup
trl/extras/best_of_n_sampler.py: 2 additions, 2 deletions

@@ -92,9 +92,9 @@ def generate(
             queries = tokenized_query.unsqueeze(0)
         elif isinstance(tokenized_query, List):
             element_type = type(tokenized_query[0])
-            if element_type == int:
+            if element_type is int:
                 queries = torch.tensor(tokenized_query).unsqueeze(0)
-            elif element_type == torch.Tensor:
+            elif element_type is torch.Tensor:
                 queries = [tensor.reshape((1, -1)) for tensor in tokenized_query]
             else:
                 queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query]
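The `==` to `is` change here (and in the trainers below) matches ruff's E721 rule: comparing types with `==` dispatches to `__eq__`, which a metaclass can override, whereas `is` compares identity and cannot be fooled. A small self-contained sketch (the classes are invented for illustration):

```python
class LooseMeta(type):
    """Metaclass whose equality always succeeds, to show how `==` can mislead."""

    def __eq__(cls, other):
        return True

    def __hash__(cls):  # keep classes hashable despite the custom __eq__
        return id(cls)


class Fake(metaclass=LooseMeta):
    pass


print(Fake == int)      # True  -- `==` is hijacked by the metaclass
print(Fake is int)      # False -- identity comparison is unambiguous
print(type(3) is int)   # True  -- the pattern now used in best_of_n_sampler.py
```

For ordinary classes the two comparisons agree, so the change is behavior-preserving; `is` simply states the intent (an exact type match) without going through `__eq__`.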
trl/models/sd_utils.py: 1 addition, 0 deletions

@@ -16,6 +16,7 @@
 File copied from diffusers to avoid import issues and make TRL compatible
 with most of diffusers versions.
 """
+
 import enum
 
 
trl/trainer/bco_trainer.py: 1 addition, 1 deletion

@@ -327,7 +327,7 @@ def __init__(
         embedding_func: Optional[Callable] = None,
         embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
     ):
-        if type(args) == TrainingArguments:
+        if type(args) is TrainingArguments:
             raise ValueError("Please use `BCOConfig` instead TrainingArguments.")
 
         if args.model_init_kwargs is None:
trl/trainer/kto_trainer.py: 1 addition, 1 deletion

@@ -315,7 +315,7 @@ def __init__(
         model_adapter_name: Optional[str] = None,
         ref_adapter_name: Optional[str] = None,
     ):
-        if type(args) == TrainingArguments:
+        if type(args) is TrainingArguments:
             raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
 
         if args.model_init_kwargs is None:
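The check in `bco_trainer.py` and `kto_trainer.py` is intentionally an exact-type test rather than `isinstance`: `BCOConfig` and `KTOConfig` subclass `TrainingArguments`, so an `isinstance` check would also reject the dedicated configs. A hedged sketch with stand-in classes (not the real `transformers`/TRL classes):

```python
from dataclasses import dataclass


@dataclass
class TrainingArguments:  # stand-in for transformers.TrainingArguments
    output_dir: str = "out"


@dataclass
class KTOConfig(TrainingArguments):  # stand-in for trl.KTOConfig
    beta: float = 0.1


args = KTOConfig()
# isinstance() would flag the dedicated config as well, which is not wanted:
print(isinstance(args, TrainingArguments))  # True
# The exact-type check only triggers for a plain TrainingArguments instance:
print(type(args) is TrainingArguments)      # False
print(type(TrainingArguments()) is TrainingArguments)  # True -> raises in the trainer
```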
trl/trainer/ppov2_trainer.py: 6 additions, 6 deletions

@@ -473,14 +473,14 @@ def repeat_generator():
                     entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                     approxkl = 0.5 * (logprobs_diff**2).mean()
                     approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                    pg_clipfrac_stats[
-                        ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
-                    ] = pg_clipfrac
+                    pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                        pg_clipfrac
+                    )
                     pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                     vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
-                    vf_clipfrac_stats[
-                        ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
-                    ] = vf_clipfrac
+                    vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                        vf_clipfrac
+                    )
                     entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                     ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                     gradient_accumulation_idx += 1
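The `ppov2_trainer.py` and `rloo_trainer.py` hunks are pure reformatting: the statistics are still written into pre-allocated tensors indexed by PPO epoch, minibatch, and gradient-accumulation step; only the line wrapping changes. A simplified sketch of that bookkeeping pattern, with assumed shapes and dummy values (not the trainer's actual loop):

```python
import torch

# Assumed dimensions, for illustration only.
num_ppo_epochs, num_minibatches, grad_accum_steps = 4, 8, 2

# One scalar slot per (epoch, minibatch, accumulation step).
approxkl_stats = torch.zeros(num_ppo_epochs, num_minibatches, grad_accum_steps)
pg_clipfrac_stats = torch.zeros(num_ppo_epochs, num_minibatches, grad_accum_steps)

for ppo_epoch_idx in range(num_ppo_epochs):
    for minibatch_idx in range(num_minibatches):
        for gradient_accumulation_idx in range(grad_accum_steps):
            # Dummy stand-ins for the quantities computed during the policy update.
            approxkl = 0.01 * (ppo_epoch_idx + 1)
            pg_clipfrac = 0.1 * (minibatch_idx + 1)

            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
            # The wrapping style from the diff: when such a line exceeds the
            # configured length, ruff-format parenthesizes the right-hand side.
            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                pg_clipfrac
            )

# The trainers log aggregates of these buffers; here we just print the means.
print(approxkl_stats.mean().item(), pg_clipfrac_stats.mean().item())
```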
trl/trainer/reward_trainer.py: 2 additions, 2 deletions

@@ -110,7 +110,7 @@ def __init__(
             peft_config (`Dict`, defaults to `None`):
                 The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
         """
-        if type(args) == TrainingArguments:
+        if type(args) is TrainingArguments:
             warnings.warn(
                 "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                 FutureWarning,

@@ -163,7 +163,7 @@ def __init__(
                 raise ValueError(
                     "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
                 )
-            if type(args) == TrainingArguments:
+            if type(args) is TrainingArguments:
                 if max_length is None:
                     warnings.warn(
                         "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
trl/trainer/rloo_trainer.py: 3 additions, 3 deletions

@@ -396,9 +396,9 @@ def repeat_generator():
                     entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                     approxkl = 0.5 * (logprobs_diff**2).mean()
                     approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                    pg_clipfrac_stats[
-                        ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
-                    ] = pg_clipfrac
+                    pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                        pg_clipfrac
+                    )
                     pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                     entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                     ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()