Separate AdapterTrainer Class #218

Merged · 32 commits · Sep 16, 2021
Commits
8a6a267
Implement AdapterTrainer with callbacks
hSterz Aug 10, 2021
dc1e1d8
Fix typo
hSterz Aug 10, 2021
3315d00
Adjust import
hSterz Aug 10, 2021
dc7bbab
Adapted run_translation for extended tests
hSterz Aug 22, 2021
e3a2c52
Allowed additional callbacks and logging
hSterz Aug 22, 2021
4ae0c66
style
hSterz Aug 22, 2021
bce7daa
Add changes on develop
hSterz Aug 22, 2021
528bd51
Quality
hSterz Aug 22, 2021
fd0166a
Quality
hSterz Aug 23, 2021
0ab336a
Merge with master
hSterz Aug 24, 2021
21b648c
Quality
hSterz Aug 24, 2021
47935b4
Overwrite save method
hSterz Sep 5, 2021
fbde693
Overwriting _save() and creating and overwriting _load()
hSterz Sep 5, 2021
0f9e68b
Style
hSterz Sep 5, 2021
bf31ebc
Added automatic saving and loading of heads to trainer
hSterz Sep 6, 2021
3b028e5
Merge branch 'hub' into dev/adapter_trainer
hSterz Sep 6, 2021
5334c20
Style
hSterz Sep 6, 2021
4312330
Style
hSterz Sep 6, 2021
968da30
Fix loading
hSterz Sep 6, 2021
7f31dde
Added test
hSterz Sep 6, 2021
f3bbb52
Additional Testcase
hSterz Sep 10, 2021
9a0587d
Change AdapterTrainer to only train adapters
hSterz Sep 10, 2021
cb930fa
Fix test
hSterz Sep 10, 2021
345225d
Quality
hSterz Sep 10, 2021
c177e2f
Fix test
hSterz Sep 13, 2021
cf10f1f
Adapt examples to new AdapterTrainer
hSterz Sep 14, 2021
be2865e
merge
hSterz Sep 14, 2021
9b5d859
Style
hSterz Sep 14, 2021
eca7283
Overwrite remove unused columns method
hSterz Sep 14, 2021
b67595e
Add extended adapter trainer test
hSterz Sep 15, 2021
b9e41ec
Fix
hSterz Sep 15, 2021
05ef6ba
Fix
hSterz Sep 15, 2021
7 changes: 2 additions & 5 deletions examples/adapterfusion/run_fusion_glue.py
@@ -28,6 +28,7 @@

 from transformers import (
     AdapterArguments,
+    AdapterTrainer,
     AutoConfig,
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -37,7 +38,6 @@
 from transformers import GlueDataTrainingArguments as DataTrainingArguments
 from transformers import (
     HfArgumentParser,
-    Trainer,
     TrainingArguments,
     glue_compute_metrics,
     glue_output_modes,
@@ -203,15 +203,12 @@ def compute_metrics(p: EvalPrediction) -> Dict:
         preds = np.squeeze(p.predictions)
         return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

-    # Initialize our Trainer
-    trainer = Trainer(
+    trainer = AdapterTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         compute_metrics=compute_metrics,
-        do_save_full_model=False,
-        do_save_adapter_fusion=True,
     )

     # Training
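The pattern in this file recurs throughout the PR: AdapterTrainer replaces Trainer together with the removed do_save_* flags, so adapter-only checkpointing needs no per-script switches. A minimal sketch of the new usage, assuming a model with heads and an illustrative adapter name (not taken from this diff):

# Hedged sketch of the new AdapterTrainer pattern; model and adapter names
# below are assumptions for illustration only.
from transformers import AdapterTrainer, AutoModelWithHeads, TrainingArguments

model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
model.add_adapter("example_task")
model.add_classification_head("example_task", num_labels=2)
model.train_adapter("example_task")  # freeze base weights, activate the adapter

trainer = AdapterTrainer(
    model=model,
    args=TrainingArguments(output_dir="./out"),
    train_dataset=None,  # supply a real dataset before calling trainer.train()
)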
7 changes: 3 additions & 4 deletions examples/dependency-parsing/run_udp.py
@@ -22,7 +22,7 @@
     MultiLingAdapterArguments,
     set_seed,
 )
-from utils_udp import UD_HEAD_LABELS, DependencyParsingTrainer, UDTrainingArguments
+from utils_udp import UD_HEAD_LABELS, DependencyParsingAdapterTrainer, DependencyParsingTrainer, UDTrainingArguments


 logger = logging.getLogger(__name__)
@@ -245,13 +245,12 @@ def main():
     # Initialize our Trainer
     # HACK: Set this attribute to False to prevent label columns from being deleted
     training_args.remove_unused_columns = False
-    trainer = DependencyParsingTrainer(
+    trainer_class = DependencyParsingAdapterTrainer if adapter_args.train_adapter else DependencyParsingTrainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=dataset["train"],
         eval_dataset=dataset["validation"],
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
13 changes: 5 additions & 8 deletions examples/dependency-parsing/utils_udp.py
@@ -16,6 +16,7 @@
 from tqdm import tqdm

 from transformers import (
+    AdapterTrainer,
     DataCollator,
     EvalPrediction,
     PreTrainedModel,
@@ -186,10 +187,6 @@ def __init__(
         model_init: Callable[[], PreTrainedModel] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
-        do_save_full_model: bool = True,
-        do_save_adapters: bool = False,
-        do_save_adapter_fusion: bool = False,
-        adapter_names: Optional[List[List[str]]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         **kwargs,
     ):
@@ -203,10 +200,6 @@
             model_init,
             compute_metrics,
             callbacks,
-            do_save_full_model,
-            do_save_adapters,
-            do_save_adapter_fusion,
-            adapter_names,
             optimizers,
             **kwargs,
         )
@@ -362,3 +355,7 @@ def _prediction_loop(

         # Add predictions_rels to output, even though we are only interested in the metrics
         return PredictionOutput(predictions=predictions_rels, label_ids=None, metrics=results)
+
+
+class DependencyParsingAdapterTrainer(AdapterTrainer, DependencyParsingTrainer):
+    pass
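DependencyParsingAdapterTrainer works through Python's method resolution order: listing AdapterTrainer first lets its overrides (e.g. adapter-only saving) win wherever both parents define a method. A self-contained toy sketch of the mechanism; the stub bodies below are placeholders, not the real trainers:

# Toy stubs illustrating mixin resolution; not the actual trainer classes.
class Trainer:
    def _save(self):
        return "full model"


class AdapterTrainer(Trainer):
    def _save(self):
        return "adapters and heads only"


class DependencyParsingTrainer(Trainer):
    pass  # task-specific evaluation logic lives here in the real class


class DependencyParsingAdapterTrainer(AdapterTrainer, DependencyParsingTrainer):
    pass


# AdapterTrainer precedes DependencyParsingTrainer in the MRO, so its
# _save() is the one that runs.
assert DependencyParsingAdapterTrainer()._save() == "adapters and heads only"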
6 changes: 3 additions & 3 deletions examples/language-modeling/run_clm.py
@@ -35,6 +35,7 @@
 from transformers import (
     CONFIG_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING,
+    AdapterTrainer,
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -480,16 +481,15 @@ def group_texts(examples):
         eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
6 changes: 3 additions & 3 deletions examples/language-modeling/run_mlm.py
@@ -35,6 +35,7 @@
 from transformers import (
     CONFIG_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING,
+    AdapterTrainer,
     AutoConfig,
     AutoModelForMaskedLM,
     AutoTokenizer,
@@ -512,15 +513,14 @@ def group_texts(examples):
     )

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
6 changes: 3 additions & 3 deletions examples/multiple-choice/run_swag.py
@@ -32,6 +32,7 @@
 import transformers.adapters.composition as ac
 from transformers import (
     AdapterConfig,
+    AdapterTrainer,
     AutoConfig,
     AutoModelForMultipleChoice,
     AutoTokenizer,
@@ -437,16 +438,15 @@ def compute_metrics(eval_predictions):
         return {"accuracy": (preds == label_ids).astype(np.float32).mean().item()}

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
7 changes: 3 additions & 4 deletions examples/question-answering/run_qa.py
@@ -27,7 +27,7 @@
 from datasets import load_dataset, load_metric

 import transformers
-from trainer_qa import QuestionAnsweringTrainer
+from trainer_qa import QuestionAnsweringAdapterTrainer, QuestionAnsweringTrainer
 from transformers import (
     AdapterConfig,
     AutoConfig,
@@ -599,7 +599,8 @@ def compute_metrics(p: EvalPrediction):
         return metric.compute(predictions=p.predictions, references=p.label_ids)

     # Initialize our Trainer
-    trainer = QuestionAnsweringTrainer(
+    trainer_class = QuestionAnsweringAdapterTrainer if adapter_args.train_adapter else QuestionAnsweringTrainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
@@ -609,8 +610,6 @@ def compute_metrics(p: EvalPrediction):
         data_collator=data_collator,
         post_process_function=post_processing_function,
         compute_metrics=compute_metrics,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
6 changes: 5 additions & 1 deletion examples/question-answering/trainer_qa.py
@@ -16,7 +16,7 @@
 A subclass of `Trainer` specific to Question-Answering tasks
 """

-from transformers import Trainer, is_torch_tpu_available
+from transformers import AdapterTrainer, Trainer, is_torch_tpu_available
 from transformers.trainer_utils import PredictionOutput


@@ -103,3 +103,7 @@ def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_ke
             metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
+
+
+class QuestionAnsweringAdapterTrainer(QuestionAnsweringTrainer, AdapterTrainer):
+    pass
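Note the base-class order is reversed here relative to DependencyParsingAdapterTrainer: QuestionAnsweringTrainer comes first, so its evaluate()/predict() overrides take precedence, while AdapterTrainer still contributes the methods the QA trainer leaves untouched. A hedged way to inspect the resulting resolution order:

# Print the MRO to see which parent wins for overridden methods
# (assumes this file is importable as trainer_qa).
from trainer_qa import QuestionAnsweringAdapterTrainer

for cls in QuestionAnsweringAdapterTrainer.__mro__:
    print(cls.__name__)
# Expected roughly: QuestionAnsweringAdapterTrainer, QuestionAnsweringTrainer,
# AdapterTrainer, Trainer, object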
6 changes: 3 additions & 3 deletions examples/summarization/run_summarization.py
@@ -40,6 +40,7 @@
     EarlyStoppingCallback,
     HfArgumentParser,
     MultiLingAdapterArguments,
+    Seq2SeqAdapterTrainer,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -585,16 +586,15 @@ def compute_metrics(eval_preds):
         training_args.load_best_model_at_end = True

     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer_class = Seq2SeqAdapterTrainer if adapter_args.train_adapter else Seq2SeqTrainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )
     if data_args.patience and data_args.patience > 0:
         callback = EarlyStoppingCallback(early_stopping_patience=data_args.patience)
6 changes: 3 additions & 3 deletions examples/text-classification/run_glue.py
@@ -30,6 +30,7 @@
 import transformers.adapters.composition as ac
 from transformers import (
     AdapterConfig,
+    AdapterTrainer,
     AutoConfig,
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -515,16 +516,15 @@ def compute_metrics(p: EvalPrediction):
         data_collator = None

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         compute_metrics=compute_metrics,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
6 changes: 3 additions & 3 deletions examples/text-classification/run_glue_alt.py
@@ -33,6 +33,7 @@
 import transformers.adapters.composition as ac
 from transformers import (
     AdapterConfig,
+    AdapterTrainer,
     AutoConfig,
     AutoModelWithHeads,
     AutoTokenizer,
@@ -402,7 +403,8 @@ def compute_metrics(p: EvalPrediction):
         return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset,
@@ -411,8 +413,6 @@ def compute_metrics(p: EvalPrediction):
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
         data_collator=default_data_collator if data_args.pad_to_max_length else None,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
6 changes: 3 additions & 3 deletions examples/token-classification/run_ner.py
@@ -32,6 +32,7 @@
 import transformers.adapters.composition as ac
 from transformers import (
     AdapterConfig,
+    AdapterTrainer,
     AutoConfig,
     AutoModelForTokenClassification,
     AutoTokenizer,
@@ -518,16 +519,15 @@ def compute_metrics(p):
         }

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )

     # Training
6 changes: 3 additions & 3 deletions examples/translation/run_translation.py
@@ -44,6 +44,7 @@
     MBartTokenizer,
     MBartTokenizerFast,
     MultiLingAdapterArguments,
+    Seq2SeqAdapterTrainer,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     default_data_collator,
@@ -581,16 +582,15 @@ def compute_metrics(eval_preds):
         training_args.load_best_model_at_end = True

     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer_class = Seq2SeqAdapterTrainer if adapter_args.train_adapter else Seq2SeqTrainer
+    trainer = trainer_class(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        do_save_full_model=not adapter_args.train_adapter,
-        do_save_adapters=adapter_args.train_adapter,
     )
     if data_args.patience and data_args.patience > 0:
         callback = EarlyStoppingCallback(early_stopping_patience=data_args.patience)
5 changes: 5 additions & 0 deletions src/transformers/__init__.py
@@ -1188,6 +1188,10 @@
         "ModelConfigAdaptersMixin",
         "ModelWithHeadsAdaptersMixin",
     ]
+    _import_structure["adapters.trainer"] = [
+        "AdapterTrainer",
+        "Seq2SeqAdapterTrainer",
+    ]
     _import_structure["adapters.training"] = [
         "AdapterArguments",
         "MultiLingAdapterArguments",
@@ -2688,6 +2692,7 @@
         ModelConfigAdaptersMixin,
         ModelWithHeadsAdaptersMixin,
     )
+    from .adapters.trainer import AdapterTrainer, Seq2SeqAdapterTrainer
     from .adapters.training import AdapterArguments, MultiLingAdapterArguments
     from .adapters.utils import (
         ADAPTER_CACHE,
5 changes: 5 additions & 0 deletions src/transformers/adapters/model_mixin.py
@@ -653,6 +653,11 @@ def save_all_adapters(
             custom_weights_loaders=custom_weights_loaders,
         )

+    def save_all_heads(self, save_directory):
+        for head_name in self.heads:
+            save_path = join(save_directory, head_name)
+            self.save_head(save_path, head_name)
+
     def get_labels(self):
         return list(self.config.id2label.values())
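The new save_all_heads mirrors the existing save_all_adapters: one subdirectory per prediction head. A hedged usage sketch; the directory and head names are illustrative assumptions:

# Persist adapters and heads side by side so a checkpoint can later be
# restored into a fresh model; names here are illustrative.
from transformers import AutoModelWithHeads

model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
model.add_adapter("example_task")
model.add_classification_head("example_task", num_labels=2)

model.save_all_adapters("checkpoint")  # writes checkpoint/<adapter_name>/
model.save_all_heads("checkpoint")     # writes checkpoint/<head_name>/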
4 changes: 4 additions & 0 deletions src/transformers/adapters/modeling.py
@@ -82,6 +82,10 @@ def __init__(
         if down_sample is None:
             self.down_sample = self.input_size // 2

+        # ensure that the down sample size is at least 1
+        if self.down_sample < 1:
+            self.down_sample = 1
+
         # Linear down projection of the input
         seq_list.append(nn.Linear(self.input_size, self.down_sample))
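The clamp guards against a degenerate bottleneck: with a small hidden size and a large reduction factor, integer division can round the down-projection width to zero, leaving the nn.Linear below with no output features. A small worked example of the failure mode; the numbers are illustrative, not from the diff:

# Illustrative numbers: hidden size 16 with a reduction factor of 64.
input_size = 16
down_sample = input_size // 64     # integer division yields 0
down_sample = max(down_sample, 1)  # the new guard restores a valid width of 1
# nn.Linear(input_size, down_sample) is now a tiny but valid projection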