Add adapter support to seq2seq example scripts (translation & summarization) #141

Merged · 3 commits · Apr 7, 2021
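In short, this PR wires the adapter-transformers workflow into the two seq2seq example scripts and adds optional early stopping via a new --patience argument. A rough sketch of the adapter flow the scripts now follow, using only calls visible in the diff below; the checkpoint name, the "pfeiffer" config identifier, and the task name are illustrative placeholders, not values fixed by this PR:

from transformers import AdapterConfig, AutoModelForSeq2SeqLM

# Illustrative checkpoint; the scripts accept any seq2seq model via --model_name_or_path.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Resolve an adapter architecture ("pfeiffer" is assumed here as the config identifier).
adapter_config = AdapterConfig.load("pfeiffer")

# Add a fresh task adapter, freeze all other weights, and activate it for every forward pass.
model.add_adapter("summarization", config=adapter_config)
model.train_adapter(["summarization"])
model.set_active_adapters(["summarization"])

The scripts reach the same state through the MultiLingAdapterArguments flags exposed by HfArgumentParser (--train_adapter, --adapter_config, --load_adapter, --load_lang_adapter, --language).
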
82 changes: 79 additions & 3 deletions examples/seq2seq/run_summarization.py
@@ -31,11 +31,15 @@
import transformers
from filelock import FileLock
from transformers import (
AdapterConfig,
AdapterType,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
EarlyStoppingCallback,
HfArgumentParser,
MultiLingAdapterArguments,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
@@ -209,6 +213,12 @@ class DataTrainingArguments:
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
patience: Optional[int] = field(
default=None,
metadata={
"help": "Stop training when the metric specified for `metric_for_best_model` worsend for `patience` number of evaluation calls."
},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -244,13 +254,17 @@ def main():
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, MultiLingAdapterArguments)
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
model_args, data_args, training_args, adapter_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses()

if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
@@ -357,6 +371,59 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

# Setup adapters
if adapter_args.train_adapter:
task_name = data_args.dataset_name or "summarization"
# check if adapter already exists, otherwise add it
if task_name not in model.config.adapters:
# resolve the adapter config
adapter_config = AdapterConfig.load(
adapter_args.adapter_config,
non_linearity=adapter_args.adapter_non_linearity,
reduction_factor=adapter_args.adapter_reduction_factor,
)
# load a pre-trained adapter from the Hub if specified
if adapter_args.load_adapter:
model.load_adapter(
adapter_args.load_adapter,
AdapterType.text_task,
config=adapter_config,
load_as=task_name,
)
# otherwise, add a fresh adapter
else:
model.add_adapter(task_name, config=adapter_config)
# optionally load a pre-trained language adapter
if adapter_args.load_lang_adapter:
# resolve the language adapter config
lang_adapter_config = AdapterConfig.load(
adapter_args.lang_adapter_config,
non_linearity=adapter_args.lang_adapter_non_linearity,
reduction_factor=adapter_args.lang_adapter_reduction_factor,
)
# load the language adapter from Hub
lang_adapter_name = model.load_adapter(
adapter_args.load_lang_adapter,
AdapterType.text_lang,
config=lang_adapter_config,
load_as=adapter_args.language,
)
else:
lang_adapter_name = None
# Freeze all model weights except those of this adapter
model.train_adapter([task_name])
# Set the adapters to be used in every forward pass
if lang_adapter_name:
model.set_active_adapters([lang_adapter_name, task_name])
else:
model.set_active_adapters([task_name])
else:
if adapter_args.load_adapter or adapter_args.load_lang_adapter:
raise ValueError(
"Adapters can only be loaded in adapters training mode."
"Use --train_adapter to enable adapter training"
)

prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

# Preprocessing the datasets.
@@ -511,6 +578,10 @@ def compute_metrics(eval_preds):
result = {k: round(v, 4) for k, v in result.items()}
return result

# Early stopping
if data_args.patience and data_args.patience > 0:
training_args.load_best_model_at_end = True

# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
@@ -520,7 +591,12 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
do_save_full_model=not adapter_args.train_adapter,
do_save_adapters=adapter_args.train_adapter,
)
if data_args.patience and data_args.patience > 0:
callback = EarlyStoppingCallback(early_stopping_patience=data_args.patience)
trainer.add_callback(callback)

# Training
if training_args.do_train:
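The early-stopping wiring above pairs the new --patience data argument with transformers' EarlyStoppingCallback. A minimal sketch of the pieces involved, assuming the standard Trainer API; the output directory, strategy, and patience value are placeholders:

from transformers import EarlyStoppingCallback, Seq2SeqTrainingArguments

# The callback needs periodic evaluation and a metric to compare against, which is why the
# script forces load_best_model_at_end=True whenever --patience is set.
training_args = Seq2SeqTrainingArguments(
    output_dir="out",                 # placeholder path
    evaluation_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)  # stop after 3 evaluations without improvement
# trainer.add_callback(early_stopping) once the Seq2SeqTrainer has been constructed
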
82 changes: 79 additions & 3 deletions examples/seq2seq/run_translation.py
@@ -29,13 +29,17 @@

import transformers
from transformers import (
AdapterConfig,
AdapterType,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
EarlyStoppingCallback,
HfArgumentParser,
MBartTokenizer,
MBartTokenizerFast,
MultiLingAdapterArguments,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
@@ -191,6 +195,12 @@ class DataTrainingArguments:
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
patience: Optional[int] = field(
default=None,
metadata={
"help": "Stop training when the metric specified for `metric_for_best_model` worsend for `patience` number of evaluation calls."
},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -213,13 +223,17 @@ def main():
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, MultiLingAdapterArguments)
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
model_args, data_args, training_args, adapter_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses()

if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
@@ -336,6 +350,59 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

# Setup adapters
if adapter_args.train_adapter:
task_name = data_args.source_lang.split("_")[0] + "_" + data_args.target_lang.split("_")[0]
# check if adapter already exists, otherwise add it
if task_name not in model.config.adapters:
# resolve the adapter config
adapter_config = AdapterConfig.load(
adapter_args.adapter_config,
non_linearity=adapter_args.adapter_non_linearity,
reduction_factor=adapter_args.adapter_reduction_factor,
)
# load a pre-trained adapter from the Hub if specified
if adapter_args.load_adapter:
model.load_adapter(
adapter_args.load_adapter,
AdapterType.text_task,
config=adapter_config,
load_as=task_name,
)
# otherwise, add a fresh adapter
else:
model.add_adapter(task_name, config=adapter_config)
# optionally load a pre-trained language adapter
if adapter_args.load_lang_adapter:
# resolve the language adapter config
lang_adapter_config = AdapterConfig.load(
adapter_args.lang_adapter_config,
non_linearity=adapter_args.lang_adapter_non_linearity,
reduction_factor=adapter_args.lang_adapter_reduction_factor,
)
# load the language adapter from Hub
lang_adapter_name = model.load_adapter(
adapter_args.load_lang_adapter,
AdapterType.text_lang,
config=lang_adapter_config,
load_as=adapter_args.language,
)
else:
lang_adapter_name = None
# Freeze all model weights except those of this adapter
model.train_adapter([task_name])
# Set the adapters to be used in every forward pass
if lang_adapter_name:
model.set_active_adapters([lang_adapter_name, task_name])
else:
model.set_active_adapters([task_name])
else:
if adapter_args.load_adapter or adapter_args.load_lang_adapter:
raise ValueError(
"Adapters can only be loaded in adapters training mode."
"Use --train_adapter to enable adapter training"
)

prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

# Preprocessing the datasets.
@@ -478,6 +545,10 @@ def compute_metrics(eval_preds):
result = {k: round(v, 4) for k, v in result.items()}
return result

# Early stopping
if data_args.patience and data_args.patience > 0:
training_args.load_best_model_at_end = True

# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
@@ -487,7 +558,12 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
do_save_full_model=not adapter_args.train_adapter,
do_save_adapters=adapter_args.train_adapter,
)
if data_args.patience and data_args.patience > 0:
callback = EarlyStoppingCallback(early_stopping_patience=data_args.patience)
trainer.add_callback(callback)

# Training
if training_args.do_train:
3 changes: 3 additions & 0 deletions src/transformers/__init__.py
@@ -3,6 +3,7 @@
# module, but to preserve other warnings. So, don't check this module at all.

__version__ = "2.0.0a1"
__hf_version__ = "4.4.2"

# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
Expand Down Expand Up @@ -2400,6 +2401,8 @@ def __getattr__(self, name: str):
# Special handling for the version, which is a constant from this module and not imported in a submodule.
if name == "__version__":
return __version__
elif name == "__hf_version__":
return __hf_version__
return super().__getattr__(name)

sys.modules[__name__] = _LazyModule(__name__, _import_structure)
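The new __hf_version__ constant lets example scripts pin the underlying HuggingFace Transformers release independently of the adapter-transformers version ("2.0.0a1" above); it is consumed by the check_min_version helper updated in src/transformers/utils/__init__.py below. A hedged usage sketch, where the pinned version string is only an example:

from transformers.utils import check_min_version

# Compares the pinned release against __hf_version__ ("4.4.2" here) rather than the
# adapter-transformers __version__, and raises ImportError if the underlying release is too old.
check_min_version("4.4.0")
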
22 changes: 6 additions & 16 deletions src/transformers/utils/__init__.py
@@ -16,23 +16,13 @@

from packaging import version

from .. import __version__
from .. import __hf_version__


def check_min_version(min_version):
if version.parse(__version__) < version.parse(min_version):
if "dev" in min_version:
error_message = (
"This example requires a source install from 🤗 Transformers (see "
"`https://huggingface.co/transformers/installation.html#installing-from-source`),"
)
else:
error_message = f"This example requires a minimum version of {min_version},"
error_message += f" but the version found is {__version__}.\n"
raise ImportError(
error_message
+ (
"Check out https://huggingface.co/transformers/examples.html for the examples corresponding to other "
"versions of 🤗 Transformers."
)
if version.parse(__hf_version__) < version.parse(min_version):
error_message = (
f"This example requires a minimum underlying HuggingFace Transformers version of {min_version},"
)
error_message += f" but the version found is {__hf_version__}.\n"
raise ImportError(error_message)