
Commit 61157ea: refactoring

Committed Nov 25, 2022 (1 parent: 4eaf613)

16 files changed: +942 -62 lines

‎.gitignore

new file: +141 lines

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode

# IntelliJ
.idea

# Mac .DS_Store
.DS_Store

# More test things
wandb

‎MANIFEST.in

new file: +1 line

include LICENSE

‎Makefile

new file: +19 lines

.PHONY: quality style test docs

check_dirs := src

# Check that source code meets quality standards

# this target runs checks on all files
quality:
	black --check $(check_dirs)
	isort --check-only $(check_dirs)
	flake8 $(check_dirs)
	python utils/style_doc.py src --max_len 119 --check_only

# Format source code automatically and check is there are any problems left that need manual fixing
style:
	black $(check_dirs)
	isort $(check_dirs)
	python utils/style_doc.py src --max_len 119

‎pyproject.toml

new file: +3 lines

[tool.black]
line-length = 119
target-version = ['py36']

‎setup.cfg

new file: +23 lines

[isort]
default_section = FIRSTPARTY
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = pet
known_third_party =
    numpy
    torch
    accelerate
    transformers

line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True

[flake8]
ignore = E203, E722, E501, E741, W503, W605
max-line-length = 119

[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS

‎setup.py

new file: +78 lines

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from setuptools import setup
from setuptools import find_packages

extras = {}
extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
extras["dev"] = extras["quality"]

setup(
    name="pets",
    version="0.1.0.dev0",
    description="Parameter-Efficient Tuning at Scale (PETS)",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    keywords="deep learning",
    license="Apache",
    author="The HuggingFace team",
    author_email="sourab@huggingface.co",
    url="https://github.com/huggingface/pets",
    package_dir={"": "src"},
    packages=find_packages("src"),
    entry_points={},
    python_requires=">=3.7.0",
    install_requires=[
        "numpy>=1.17",
        "packaging>=20.0",
        "psutil",
        "pyyaml",
        "torch>=1.4.0",
        "transformers",
        "accelerate",
    ],
    extras_require=extras,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)

# Release checklist
# 1. Change the version in __init__.py and setup.py.
# 2. Commit these changes with the message: "Release: VERSION"
# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
#    Push the tag to git: git push --tags origin main
# 4. Run the following commands in the top-level directory:
#      python setup.py bdist_wheel
#      python setup.py sdist
# 5. Upload the package to the pypi test server first:
#      twine upload dist/* -r pypitest
#      twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
# 6. Check that you can install it in a virtualenv by running:
#      pip install -i https://testpypi.python.org/pypi accelerate
#      accelerate env
#      accelerate test
# 7. Upload the final version to actual pypi:
#      twine upload dist/* -r pypi
# 8. Add release notes to the tag in github once everything is looking hunky-dory.
# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master
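
A quick way to see what the src/ layout above resolves to (a sketch, assuming it is run from the repository root at this commit):

from setuptools import find_packages

# With src/pet/__init__.py and src/pet/tuners/__init__.py added in this commit,
# package_dir={"": "src"} plus find_packages("src") should pick up both
# packages, i.e. ['pet', 'pet.tuners'].
print(sorted(find_packages("src")))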

‎src/pet/__init__.py

new file: +18 lines

# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

__version__ = "0.1.0.dev0"

from .pet_model import (
    ParameterEfficientTuningModel,
    ParameterEfficientTuningModelForSequenceClassification,
    PromptEncoderType,
)
from .tuners import (
    PrefixEncoder,
    PromptEmbedding,
    PromptEncoder,
    PromptEncoderReparameterizationType,
    PromptTuningInit,
)
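
For orientation, a minimal usage sketch of the public surface this __init__.py exposes (an assumption about intended usage, not code from this commit; note that pet_model.py below still imports tuners.* with absolute paths at this commit, so importing pet as an installed package may need a follow-up fix):

# Assumes the package has been installed; the distribution is named "pets",
# the importable package is "pet".
from pet import (
    ParameterEfficientTuningModel,
    ParameterEfficientTuningModelForSequenceClassification,
    PromptEncoderType,
    PrefixEncoder,
    PromptEmbedding,
    PromptEncoder,
    PromptEncoderReparameterizationType,
    PromptTuningInit,
)

print(list(PromptEncoderType))  # the prompt-encoder variants defined in pet_model.py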

‎src/pet.py ‎src/pet/pet_model.py

renamed: src/pet.py → src/pet/pet_model.py, +8 -6

@@ -1,12 +1,14 @@
-from collections import OrderedDict
 import enum
 import warnings
+from collections import OrderedDict
+
 import torch
+from accelerate.state import AcceleratorState
 from transformers import PreTrainedModel
+
 from tuners.p_tuning import PromptEncoder
 from tuners.prefix_tuning import PrefixEncoder
 from tuners.prompt_tuning import PromptEmbedding
-from accelerate.state import AcceleratorState
 
 
 class PromptEncoderType(str, enum.Enum):
@@ -88,8 +90,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
 
     def load_state_dict(self, state_dict, strict: bool = True):
         """
-        Custom load state dict method that only loads prompt table and prompt encoder
-        parameters. Matching load method for this class' custom state dict method.
+        Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+        for this class' custom state dict method.
         """
         self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict)
 
@@ -187,8 +189,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
 
     def load_state_dict(self, state_dict, strict: bool = True):
         """
-        Custom load state dict method that only loads prompt table and prompt encoder
-        parameters. Matching load method for this class' custom state dict method.
+        Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+        for this class' custom state dict method.
         """
         super().load_state_dict(state_dict["prompt_encoder"], strict)
         self.model.classifier.load_state_dict(state_dict["classifier"], strict)
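
The reflowed docstrings above describe a checkpointing path that saves and restores only the prompt parameters. A standalone sketch of that pattern (the embedding sizes are illustrative, not taken from pet_model.py):

import torch

# Stand-in for the prompt encoder's embedding table (20 virtual tokens, dim 768).
embedding = torch.nn.Embedding(20, 768)

# "state_dict" side: keep only the prompt embeddings, not the frozen base model.
checkpoint = {"prompt_embeddings": embedding.weight.detach().clone()}

# "load_state_dict" side: restore them the same way the class does, i.e.
# embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict).
embedding.load_state_dict({"weight": checkpoint["prompt_embeddings"]}, strict=True)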

‎src/prompt_learning_legacy.py ‎src/pet/prompt_learning_legacy.py

renamed: src/prompt_learning_legacy.py → src/pet/prompt_learning_legacy.py, +66 -48

@@ -1,34 +1,27 @@
 import enum
-import torch
+import functools
 import math
 import os
+from collections import OrderedDict
 
-from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
-from transformers import PreTrainedModel
-from transformers.modeling_outputs import SequenceClassifierOutput
-from transformers import AutoModelForSequenceClassification
-from datasets import load_dataset
-import evaluate
 import torch
-from transformers import AutoTokenizer, get_linear_schedule_with_warmup, set_seed
-from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
 from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
-import functools
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel,
-    CPUOffload,
-)
-from torch.distributed.fsdp.wrap import (
-    enable_wrap,
-    wrap,
-    ModuleWrapPolicy,
-    transformer_auto_wrap_policy,
-    lambda_auto_wrap_policy,
-    _or_policy,
+from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.utils.data import DataLoader
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    get_linear_schedule_with_warmup,
+    set_seed,
 )
-from collections import OrderedDict
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+import evaluate
+from datasets import load_dataset
 
 
 class PromptEncoderReparameterizationType(str, enum.Enum):
@@ -49,8 +42,7 @@ class PromptTuningInit(str, enum.Enum):
 
 class PromptEncoder(torch.nn.Module):
     """
-    The prompt encoder network that is used to generate the virtual
-    token embeddings for p-tuning.
+    The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
     """
 
     def __init__(self, config):
@@ -92,13 +84,23 @@ def __init__(self, config):
             )
 
         elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
-            layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()]
-            layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()])
+            layers = [
+                torch.nn.Linear(self.input_size, self.hidden_size),
+                torch.nn.ReLU(),
+            ]
+            layers.extend(
+                [
+                    torch.nn.Linear(self.hidden_size, self.hidden_size),
+                    torch.nn.ReLU(),
+                ]
+            )
             layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
             self.mlp_head = torch.nn.Sequential(*layers)
 
         else:
-            raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
+            raise ValueError(
+                "Prompt encoder type not recognized. " " Please use one of MLP (recommended) or LSTM."
+            )
 
     def forward(self, indices):
         input_embeds = self.embedding(indices)
@@ -130,11 +132,15 @@ def __init__(self, config):
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]),
+                torch.nn.Linear(
+                    config["prompt_hidden_size"],
+                    config["num_layers"] * 2 * config["token_dim"],
+                ),
             )
         else:
             self.embedding = torch.nn.Embedding(
-                config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"]
+                config["num_virtual_tokens"],
+                config["num_layers"] * 2 * config["token_dim"],
             )
 
     def forward(self, prefix: torch.Tensor):
@@ -247,8 +253,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
 
     def load_state_dict(self, state_dict, strict: bool = True):
         """
-        Custom load state dict method that only loads prompt table and prompt encoder
-        parameters. Matching load method for this class' custom state dict method.
+        Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+        for this class' custom state dict method.
        """
         self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict)
 
@@ -389,8 +395,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
 
     def load_state_dict(self, state_dict, strict: bool = True):
         """
-        Custom load state dict method that only loads prompt table and prompt encoder
-        parameters. Matching load method for this class' custom state dict method.
+        Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+        for this class' custom state dict method.
        """
         super().load_state_dict(state_dict["prompt_encoder"], strict)
         self.classifier.load_state_dict(state_dict["classifier"], strict)
@@ -528,7 +534,7 @@ def main():
     batch_size = 16
     lr = 5e-3
     num_epochs = 100
-    device = "cuda"
+    # device = "cuda"
     seed = 11
     set_seed(seed)
 
@@ -544,7 +550,12 @@ def main():
 
     def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
-        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
+        outputs = tokenizer(
+            examples["sentence1"],
+            examples["sentence2"],
+            truncation=True,
+            max_length=None,
+        )
         return outputs
 
     # Apply the method we just defined to all the examples in all the splits of the dataset
@@ -564,10 +575,16 @@ def collate_fn(examples):
 
     # Instantiate dataloaders.
     train_dataloader = DataLoader(
-        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
+        tokenized_datasets["train"],
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=batch_size,
    )
     eval_dataloader = DataLoader(
-        tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
+        tokenized_datasets["validation"],
+        shuffle=False,
+        collate_fn=collate_fn,
+        batch_size=batch_size,
    )
 
     # Instantiate optimizer
@@ -582,9 +599,13 @@ def collate_fn(examples):
 
     accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
 
-    model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
-        model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
-    )
+    (
+        model,
+        train_dataloader,
+        eval_dataloader,
+        optimizer,
+        lr_scheduler,
+    ) = accelerator.prepare(model, train_dataloader, eval_dataloader, optimizer, lr_scheduler)
     accelerator.print(model)
 
     for epoch in range(num_epochs):
@@ -616,17 +637,14 @@ def collate_fn(examples):
         accelerator.print(f"epoch {epoch}:", eval_metric)
         accelerator.print(f"epoch {epoch} train loss:", total_loss / len(train_dataloader))
 
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
     from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp.fully_sharded_data_parallel import (
-        BackwardPrefetch,
-        CPUOffload,
-        FullStateDictConfig,
-        ShardingStrategy,
-        StateDictType,
-    )
+    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
     FSDP.set_state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        model,
+        StateDictType.FULL_STATE_DICT,
+        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    )
     state_dict = model.state_dict()
     state_dict = model.clean_state_dict(state_dict)
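
The tokenizer and DataLoader calls reformatted above are easier to read end to end in a small sketch (the checkpoint name, the example sentences, and the collate_fn body are assumptions for illustration; the diff only reflows the call sites):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

def tokenize_function(examples):
    # max_length=None => use the model max length (it's actually the default)
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

def collate_fn(examples):
    # Pad each batch to its longest member and return PyTorch tensors.
    return tokenizer.pad(examples, padding="longest", return_tensors="pt")

encoded = tokenize_function(
    {"sentence1": ["A cat sat.", "Dogs bark."], "sentence2": ["A cat was sitting.", "Dogs are loud."]}
)
features = [{k: encoded[k][i] for k in encoded} for i in range(2)]
loader = DataLoader(features, shuffle=True, collate_fn=collate_fn, batch_size=2)
print(next(iter(loader))["input_ids"].shape)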

‎src/pet/tuners/__init__.py

new file: +7 lines

# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

from .p_tuning import PromptEncoder, PromptEncoderReparameterizationType
from .prefix_tuning import PrefixEncoder
from .prompt_tuning import PromptEmbedding, PromptTuningInit

File renamed without changes.

‎src/tuners/p_tuning.py ‎src/pet/tuners/p_tuning.py

renamed: src/tuners/p_tuning.py → src/pet/tuners/p_tuning.py, +13 -5

@@ -1,6 +1,7 @@
-import torch
 import enum
 
+import torch
+
 
 class PromptEncoderReparameterizationType(str, enum.Enum):
     MLP = "MLP"
@@ -11,8 +12,7 @@ class PromptEncoderReparameterizationType(str, enum.Enum):
 # with some refactor
 class PromptEncoder(torch.nn.Module):
     """
-    The prompt encoder network that is used to generate the virtual
-    token embeddings for p-tuning.
+    The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
     """
 
     def __init__(self, config):
@@ -54,8 +54,16 @@ def __init__(self, config):
             )
 
         elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
-            layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()]
-            layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()])
+            layers = [
+                torch.nn.Linear(self.input_size, self.hidden_size),
+                torch.nn.ReLU(),
+            ]
+            layers.extend(
+                [
+                    torch.nn.Linear(self.hidden_size, self.hidden_size),
+                    torch.nn.ReLU(),
+                ]
+            )
             layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
             self.mlp_head = torch.nn.Sequential(*layers)
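
A tiny standalone version of the MLP branch reflowed above, to make the tensor shapes concrete (input_size, hidden_size, output_size and the token count are illustrative values, not the repository's defaults):

import torch

input_size, hidden_size, output_size, num_virtual_tokens = 768, 128, 768, 20

layers = [
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.ReLU(),
]
layers.extend(
    [
        torch.nn.Linear(hidden_size, hidden_size),
        torch.nn.ReLU(),
    ]
)
layers.append(torch.nn.Linear(hidden_size, output_size))
mlp_head = torch.nn.Sequential(*layers)

# Virtual-token indices -> embeddings -> reparameterized prompt embeddings.
embedding = torch.nn.Embedding(num_virtual_tokens, input_size)
indices = torch.arange(num_virtual_tokens).unsqueeze(0)
print(mlp_head(embedding(indices)).shape)  # torch.Size([1, 20, 768])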

‎src/tuners/prefix_tuning.py ‎src/pet/tuners/prefix_tuning.py

renamed: src/tuners/prefix_tuning.py → src/pet/tuners/prefix_tuning.py, +7 -2

@@ -1,5 +1,6 @@
 import torch
 
+
 # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
 # with some refactor
 class PrefixEncoder(torch.nn.Module):
@@ -20,11 +21,15 @@ def __init__(self, config):
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]),
+                torch.nn.Linear(
+                    config["prompt_hidden_size"],
+                    config["num_layers"] * 2 * config["token_dim"],
+                ),
             )
         else:
             self.embedding = torch.nn.Embedding(
-                config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"]
+                config["num_virtual_tokens"],
+                config["num_layers"] * 2 * config["token_dim"],
             )
 
     def forward(self, prefix: torch.Tensor):
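
The widened Linear above is the core of prefix tuning: every virtual token is projected to num_layers * 2 * token_dim values, one key and one value vector per transformer layer. A self-contained sketch with illustrative sizes (the concrete numbers are assumptions, not repository defaults):

import torch

token_dim, prompt_hidden_size, num_layers, num_virtual_tokens = 768, 512, 12, 20

embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
trans = torch.nn.Sequential(
    torch.nn.Linear(token_dim, prompt_hidden_size),
    torch.nn.Tanh(),
    torch.nn.Linear(prompt_hidden_size, num_layers * 2 * token_dim),
)

prefix = torch.arange(num_virtual_tokens).unsqueeze(0)  # (1, 20)
past_key_values = trans(embedding(prefix))              # (1, 20, 12 * 2 * 768) = (1, 20, 18432)
print(past_key_values.shape)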

‎src/tuners/prompt_tuning.py ‎src/pet/tuners/prompt_tuning.py

renamed: src/tuners/prompt_tuning.py → src/pet/tuners/prompt_tuning.py, +2 -1

@@ -1,7 +1,8 @@
-import torch
 import enum
 import math
 
+import torch
+
 
 class PromptTuningInit(str, enum.Enum):
     TEXT = "TEXT"

File renamed without changes.

‎utils/style_doc.py

new file: +556 lines

Large diffs are not rendered by default.
