pre-commit: replace linters + formatters with Ruff; fix some issues #1300

Merged: 8 commits, Feb 15, 2024
37 changes: 5 additions & 32 deletions .pre-commit-config.yaml
@@ -1,37 +1,10 @@
repos:
- - repo: https://github.com/PyCQA/isort
- rev: 5.12.0
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.2.0
hooks:
- - id: isort
- args:
- - --profile=black
- - --skip-glob=wandb/**/*
- - --thirdparty=wandb
- - repo: https://github.com/myint/autoflake
- rev: v1.4
- hooks:
- - id: autoflake
- args:
- - -r
- - --exclude=wandb,__init__.py
- - --in-place
- - --remove-unused-variables
- - --remove-all-unused-imports
- - repo: https://github.com/python/black
- rev: 22.3.0
- hooks:
- - id: black
- args:
- - --line-length=119
- - --target-version=py38
- - --exclude=wandb
- - repo: https://github.com/pycqa/flake8
- rev: 6.0.0
- hooks:
- - id: flake8
- args:
- - --ignore=E203,E501,W503,E128
- - --max-line-length=119
+ - id: ruff
+ args: [ --fix ]
+ - id: ruff-format

# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0
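For a rough sense of what the two hooks now cover, here is a hypothetical before/after sketch (the snippet is illustrative, not code from this repo): `ruff --fix` takes over the lint side that isort, autoflake, and flake8 used to split between them, and `ruff-format` takes over Black's formatting.

    # Hypothetical module before the hooks run:
    #
    #   import sys, json          # E401: two imports on one line; json is unused (F401)
    #   def greet(name):
    #       print( 'hello %s, from %s'%(name,sys.platform) )
    #
    # Roughly what `ruff --fix` plus `ruff-format` leave behind: the unused import
    # dropped, one import per line, spacing and quotes normalized.
    import sys


    def greet(name: str) -> None:
        print("hello %s, from %s" % (name, sys.platform))


    greet("world")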
2 changes: 1 addition & 1 deletion examples/hello_world.py
@@ -29,7 +29,7 @@
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": 20,
}
- response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
+ response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# 5. define a reward for response
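The `[item for item in query_tensor]` to `list(query_tensor)` change matches Ruff's unnecessary-comprehension rule (C416, pulled in via the `C` prefix in the new config). A minimal sketch of the equivalence, using a made-up tensor rather than the script's real tokenized query:

    import torch

    query_tensor = torch.arange(6).reshape(2, 3)  # stand-in for a batch of token ids

    # Both build the same list of row tensors; the comprehension adds nothing.
    rows_verbose = [item for item in query_tensor]  # what the old code did
    rows_concise = list(query_tensor)               # what the fixed code does

    assert all(torch.equal(a, b) for a, b in zip(rows_verbose, rows_concise))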
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -163,7 +162,7 @@ def preprocess_function(examples):


def collator(data):
- return dict((key, [d[key] for d in data]) for key in data[0])
+ return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
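The collator rewrite, `dict((key, ...) for key in ...)` into a dict comprehension, is presumably Ruff's unnecessary-generator rule (C402). A small sketch with a toy batch shaped the way the collator expects, a list of per-sample dicts:

    # Toy batch: two samples, each a dict of fields (names are illustrative).
    data = [
        {"input_ids": [1, 2], "label": 0},
        {"input_ids": [3, 4], "label": 1},
    ]

    # Old spelling: a generator of (key, value) pairs handed to dict().
    collated_old = dict((key, [d[key] for d in data]) for key in data[0])

    # New spelling: the same result as a dict comprehension.
    collated_new = {key: [d[key] for d in data] for key in data[0]}

    assert collated_old == collated_new == {"input_ids": [[1, 2], [3, 4]], "label": [0, 1]}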
3 changes: 1 addition & 2 deletions examples/research_projects/tools/calculator.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -107,7 +106,7 @@ def exact_match_reward(responses, answers=None):
)

# main training loop
- for step in range(100):
+ for _step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
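Prefixing an unused loop variable with an underscore is the convention the bugbear rule B007 (enabled through the `B` prefix) expects when the counter exists only to repeat the body. A minimal sketch:

    results = []
    for _step in range(3):  # the counter itself is never read inside the loop
        results.append("one training iteration")  # stand-in for the real PPO step

    assert len(results) == 3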
5 changes: 2 additions & 3 deletions examples/research_projects/tools/python_interpreter.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,9 +60,9 @@ def exact_match_reward(responses, answers=None):
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
- if np.abs((predicted_number - float(answer))) < 0.1:
+ if np.abs(predicted_number - float(answer)) < 0.1:
reward += 1.0
- except: # noqa
+ except Exception:
pass
rewards.append(torch.tensor(reward))
return rewards
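Replacing the bare `except:` with `except Exception:` is what pycodestyle's E722 (part of the `E` selection) asks for: a bare except also swallows `KeyboardInterrupt` and `SystemExit`. A minimal sketch of the same guard around a made-up parser rather than the script's reward logic:

    def parse_reward(answer: str) -> float:
        """Return 1.0 when the answer parses as a number, else 0.0."""
        try:
            float(answer)
            return 1.0
        except Exception:
            # Narrower than a bare `except:`, which would also catch
            # KeyboardInterrupt and SystemExit.
            return 0.0


    assert parse_reward("3.14") == 1.0
    assert parse_reward("not a number") == 0.0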
17 changes: 9 additions & 8 deletions examples/research_projects/tools/triviaqa.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -114,7 +113,7 @@ class ScriptArguments:

def data_generator():
for i in range(len(dataset)):
- yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
+ yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])


gen = data_generator()
@@ -123,7 +122,7 @@ def data_generator():

def generate_data(n):
tasks, answers = [], []
- for i in range(n):
+ for _i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
@@ -143,10 +142,14 @@ def exact_match_reward(responses, answers=None):
return rewards


+ def tool_fn(x):
+     # limit the amount of tokens
+     return tool(x).split("\n")[1][:600]


# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
- # limit the amount if tokens
- tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa

text_env = TextEnvironment(
model,
tokenizer,
@@ -184,8 +187,6 @@ def print_trainable_parameters(model):
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
- ppo_trainer.log_stats(
-     train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"]
- )
+ ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")
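Turning the `tool_fn` lambda into a proper `def` is the fix for pycodestyle's E731 (do not assign a lambda to a name). A sketch of the same truncation pattern, with a fake tool standing in for the loaded pyserini wiki tool:

    def fake_tool(query: str) -> str:
        # Stand-in for the real tool: a header line, then the passage text.
        return f"header\n{'passage about ' + query:<700}"


    # Discouraged (E731): the lambda gets no real name in tracebacks.
    truncate_lambda = lambda x: fake_tool(x).split("\n")[1][:600]  # noqa: E731


    # Preferred: a def with a docstring, same behaviour.
    def truncate(x: str) -> str:
        """Keep only the first 600 characters of the passage line."""
        return fake_tool(x).split("\n")[1][:600]


    assert truncate("python") == truncate_lambda("python")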
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -146,7 +145,7 @@ def tokenize(sample):


def collator(data):
- return dict((key, [d[key] for d in data]) for key in data[0])
+ return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
@@ -218,7 +217,7 @@ def collator(data):
response_tensors.append(response.squeeze()[-gen_len:])
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

- # Compute sentiment score # noqa
+ # Compute sentiment score
texts = batch["response"]
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
ppo_trainer.accelerator.device
1 change: 0 additions & 1 deletion examples/scripts/dpo.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
5 changes: 2 additions & 3 deletions examples/scripts/ppo.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -95,7 +94,7 @@ def tokenize(sample):


def collator(data):
- return dict((key, [d[key] for d in data]) for key in data[0])
+ return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
@@ -171,7 +170,7 @@ def collator(data):
"max_new_tokens": 32,
}

- for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
+ for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
query_tensors = batch["input_ids"]

# Get response from gpt2
5 changes: 2 additions & 3 deletions examples/scripts/ppo_multi_adapter.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -97,7 +96,7 @@ def tokenize(example):


def collator(data):
- return dict((key, [d[key] for d in data]) for key in data[0])
+ return {key: [d[key] for d in data] for key in data[0]}


config = PPOConfig(
@@ -131,7 +130,7 @@ def collator(data):
"max_new_tokens": 32,
}

- for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
+ for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]

response_tensors = ppo_trainer.generate(
1 change: 0 additions & 1 deletion examples/scripts/reward_modeling.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
1 change: 0 additions & 1 deletion examples/scripts/sft.py
@@ -1,4 +1,3 @@
- # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
26 changes: 16 additions & 10 deletions pyproject.toml
@@ -1,16 +1,22 @@
- [tool.black]
- line-length = 119
- target-version = ['py38']

[tool.ruff]
- ignore = ["E501", "E741", "W605"]
- select = ["E", "F", "I", "W"]
+ target-version = "py37"
line-length = 119

- # Ignore import violations in all `__init__.py` files.
- [tool.ruff.per-file-ignores]
- "__init__.py" = ["E402", "F401", "F403", "F811"]
+ [tool.ruff.lint]
+ ignore = [
+     "B028", # warning without explicit stacklevel
+     "C408", # dict() calls (stylistic)
+     "C901", # function complexity
+     "E501",
+ ]
+ extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"]

+ [tool.ruff.lint.per-file-ignores]
+ # Allow prints in auxiliary scripts
+ "benchmark/**.py" = ["T201"]
+ "examples/**.py" = ["T201"]
+ "scripts/**.py" = ["T201"]

- [tool.ruff.isort]
+ [tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["trl"]
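Beyond the old `E`/`F`/`I`/`W` set, the new `[tool.ruff.lint]` table pulls in pyupgrade (`UP`), bugbear (`B`), comprehension/complexity (`C`) and print (`T`, which includes T201) checks, while the per-file-ignores keep `print` legal under benchmark/, examples/ and scripts/. A hedged sketch of what that split means in practice (the logger setup is illustrative, not code from this repo):

    import logging

    logger = logging.getLogger(__name__)


    def report(step: int, loss: float) -> None:
        # In library code under trl/ a bare print would now trip T201 ...
        # print(f"step {step}: loss {loss:.4f}")
        # ... so logging is the lint-clean option there.
        logger.info("step %d: loss %.4f", step, loss)


    # Example and benchmark scripts, by contrast, may keep printing because of
    # the per-file-ignores above.
    report(1, 0.25)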
2 changes: 1 addition & 1 deletion scripts/log_example_reports.py
@@ -31,7 +31,7 @@ def main(text_file_name, slack_channel_name=None):
if os.path.isfile(text_file_name):
final_results = {}

- file = open(text_file_name, "r")
+ file = open(text_file_name)
lines = file.readlines()
for line in lines:
result, config_name = line.split(",")
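Dropping the redundant `"r"` argument is pyupgrade's UP015: `open()` already defaults to text-mode reading. A minimal sketch on a throwaway temp file rather than the script's real report file, also showing the context-manager form of the same read:

    import tempfile

    # Write a one-line report to a temp file so the read below has something real.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("passed,example_config\n")
        report_path = tmp.name

    # Equivalent to open(report_path, "r"): text-mode read is the default.
    with open(report_path) as f:
        lines = f.readlines()

    result, config_name = lines[0].strip().split(",")
    assert (result, config_name) == ("passed", "example_config")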
2 changes: 1 addition & 1 deletion scripts/log_reports.py
@@ -40,7 +40,7 @@ def main(slack_channel_name=None):
for log in Path().glob("*.log"):
section_num_failed = 0
i = 0
- with open(log, "r") as f:
+ with open(log) as f:
for line in f:
line = json.loads(line)
i += 1
2 changes: 1 addition & 1 deletion scripts/stale.py
@@ -35,7 +35,7 @@ def main():
open_issues = repo.get_issues(state="open")

for issue in open_issues:
- comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
+ comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True)
last_comment = comments[0] if len(comments) > 0 else None
if (
last_comment is not None
9 changes: 0 additions & 9 deletions setup.cfg
@@ -1,11 +1,2 @@
[metadata]
license_file = LICENSE

- [isort]
- ensure_newline_before_comments = True
- force_grid_wrap = 0
- include_trailing_comma = True
- line_length = 119
- lines_after_imports = 2
- multi_line_output = 3
- use_parentheses = True
10 changes: 5 additions & 5 deletions tests/test_no_peft.py
@@ -95,14 +95,14 @@ def test_no_peft(self):

# Check that loading a model with `peft` will raise an error
with pytest.raises(ModuleNotFoundError):
- import peft # noqa
+ import peft # noqa: F401

- trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) # noqa
- trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) # noqa
+ _trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
+ _trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)

def test_imports_no_peft(self):
with patch.dict(sys.modules, {"peft": None}):
- from trl import ( # noqa
+ from trl import ( # noqa: F401
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PPOConfig,
@@ -141,7 +141,7 @@ def test_ppo_trainer_no_peft(self):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
- train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
+ train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break

# check gradients are not None
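Two small conventions show up in the test changes: blanket `# noqa` comments become the scoped `# noqa: F401` (unused import), so every other rule stays active on that line, and results kept only to exercise a call get a leading underscore instead of a suppression comment. A minimal sketch, with `json` standing in for the optional `peft` import:

    import sys

    # Scoped suppression: only the unused-import rule (F401) is silenced here;
    # anything else on the line would still be reported.
    import json  # noqa: F401

    # The value is computed only to exercise the call (mirroring the test's
    # from_pretrained checks), so the name is underscore-prefixed rather than
    # silenced with a blanket noqa.
    _platform = sys.platform.upper()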