Skip to content

Commit

Permalink
Merge pull request #3480 from janpf/qlora
Browse files Browse the repository at this point in the history
Add PEFT training and explicit kwarg passthrough
  • Loading branch information
alanakbik authored Jul 15, 2024
2 parents 08b45e9 + 6bcc677 commit b7cc211
Showing 1 changed file with 64 additions and 7 deletions.
71 changes: 64 additions & 7 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,11 @@ def __init__(
force_max_length: bool = False,
needs_manual_ocr: Optional[bool] = None,
use_context_separator: bool = True,
transformers_tokenizer_kwargs: Dict[str, Any] = {},
transformers_config_kwargs: Dict[str, Any] = {},
transformers_model_kwargs: Dict[str, Any] = {},
peft_config=None,
peft_gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = {},
**kwargs,
) -> None:
"""Instantiate transformers embeddings.
Expand All @@ -1023,6 +1028,11 @@ def __init__(
force_max_length: If True, the tokenizer will always pad the sequences to maximum length.
needs_manual_ocr: If True, bounding boxes will be calculated manually. This is used for models like `layoutlm <https://huggingface.co/docs/transformers/model_doc/layoutlm>`_ where the tokenizer doesn't compute the bounding boxes itself.
use_context_separator: If True, the embedding will hold an additional token to allow the model to distingulish between context and prediction.
transformers_tokenizer_kwargs: Further values forwarded to the initialization of the transformers tokenizer
transformers_config_kwargs: Further values forwarded to the initialization of the transformers config
transformers_model_kwargs: Further values forwarded to the initialization of the transformers model
peft_config: If set, the model will be trained using adapters and optionally QLoRA. Should be of type "PeftConfig" or subtype
peft_gradient_checkpointing_kwargs: Further values used when preparing the model for kbit training. Only used if peft_config is set.
**kwargs: Further values forwarded to the transformers config
"""
self.instance_parameters = self.get_instance_parameters(locals=locals())
Expand All @@ -1042,7 +1052,9 @@ def __init__(

if tokenizer_data is None:
# load tokenizer and transformer model
self.tokenizer = AutoTokenizer.from_pretrained(model, add_prefix_space=True, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(
model, add_prefix_space=True, **transformers_tokenizer_kwargs, **kwargs
)
try:
self.feature_extractor = AutoFeatureExtractor.from_pretrained(model, apply_ocr=False)
except OSError:
Expand All @@ -1060,22 +1072,67 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
return getattr(config, "model_type", "") in t5_supported_model_types

if saved_config is None:
config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
config = AutoConfig.from_pretrained(
model, output_hidden_states=True, **transformers_config_kwargs, **kwargs
)

if is_supported_t5_model(config):
from transformers import T5EncoderModel

transformer_model = T5EncoderModel.from_pretrained(model, config=config)
transformer_model = T5EncoderModel.from_pretrained(
model, config=config, **transformers_model_kwargs, **kwargs
)
else:
transformer_model = AutoModel.from_pretrained(model, config=config)
transformer_model = AutoModel.from_pretrained(
model, config=config, **transformers_model_kwargs, **kwargs
)
else:
if is_supported_t5_model(saved_config):
from transformers import T5EncoderModel

transformer_model = T5EncoderModel(saved_config, **kwargs)
transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs, **kwargs)
else:
transformer_model = AutoModel.from_config(saved_config, **kwargs)
transformer_model = transformer_model.to(flair.device)
transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs, **kwargs)
try:
transformer_model = transformer_model.to(flair.device)
except ValueError as e:
# if model is quantized by BitsAndBytes this will fail
if "Please use the model as it is" not in str(e):
raise e

if peft_config is not None:
# add adapters for finetuning
try:
from peft import (
TaskType,
get_peft_model,
prepare_model_for_kbit_training,
)
except ImportError:
log.error("You cannot use the PEFT finetuning without peft being installed")
raise
# peft_config: PeftConfig
if peft_config.task_type is None:
peft_config.task_type = TaskType.FEATURE_EXTRACTION
if peft_config.task_type != TaskType.FEATURE_EXTRACTION:
log.warn("The task type for PEFT should be set to FEATURE_EXTRACTION, as it is the only supported type")
if (
"load_in_4bit" in {**kwargs, **transformers_model_kwargs}
or "load_in_8bit" in {**kwargs, **transformers_model_kwargs}
or "quantization_config" in {**kwargs, **transformers_model_kwargs}
):
transformer_model = prepare_model_for_kbit_training(
transformer_model,
**(peft_gradient_checkpointing_kwargs or {}),
)
transformer_model = get_peft_model(model=transformer_model, peft_config=peft_config)

trainable_params, all_param = transformer_model.get_nb_trainable_parameters()
log.info(
f"trainable params: {trainable_params:,d} || "
f"all params: {all_param:,d} || "
f"trainable%: {100 * trainable_params / all_param:.4f}"
)

self.truncate = True
self.force_max_length = force_max_length
Expand Down

0 comments on commit b7cc211

Please sign in to comment.