
Refactor dispatching logic of LoRA layers #1319

Conversation

@BenjaminBossan (Member) commented Jan 4, 2024

This PR's goal is to simplify the logic that decides which LoRA layer backend is used when LoRA is applied to a target layer.

Originally, this refactor was part of #1286, which was about adding the "fast" backend for LoRA, but since that PR was closed, I moved the refactor to this dedicated PR.

Motivation

Right now, the LoraModel._create_new_module method has become quite complex and hard to read, spanning more than 100 lines:

@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
    # avoid eager bnb import
    if is_bnb_available():
        import bitsandbytes as bnb

        from .bnb import Linear8bitLt

    if is_bnb_4bit_available():
        from .bnb import Linear4bit

    gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

    loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
    loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    megatron_core = None
    if lora_config.megatron_config:
        megatron_core = importlib.import_module(lora_config.megatron_core)

    if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
        eightbit_kwargs = kwargs.copy()
        eightbit_kwargs.update(
            {
                "has_fp16_weights": target.state.has_fp16_weights,
                "memory_efficient_backward": target.state.memory_efficient_backward,
                "threshold": target.state.threshold,
                "index": target.index,
            }
        )
        new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
    elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
        fourbit_kwargs = kwargs.copy()
        fourbit_kwargs.update(
            {
                "compute_dtype": target_base_layer.compute_dtype,
                "compress_statistics": target_base_layer.weight.compress_statistics,
                "quant_type": target_base_layer.weight.quant_type,
            }
        )
        new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
    elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
        new_module = QuantLinear(target, adapter_name, **kwargs)
        target.qweight = target_base_layer.qweight
    elif isinstance(target_base_layer, torch.nn.Embedding):
        embedding_kwargs = kwargs.copy()
        embedding_kwargs.pop("fan_in_fan_out", None)
        embedding_kwargs.update(lora_config.loftq_config)
        new_module = Embedding(target, adapter_name, **embedding_kwargs)
    elif isinstance(target_base_layer, torch.nn.Conv2d):
        kwargs.update(lora_config.loftq_config)
        new_module = Conv2d(target, adapter_name, **kwargs)
    elif isinstance(target_base_layer, torch.nn.Linear):
        if kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        kwargs.update(lora_config.loftq_config)
        new_module = Linear(target, adapter_name, **kwargs)
    elif megatron_core and isinstance(
        target_base_layer,
        (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
    ):
        from .tp_layer import LoraParallelLinear

        megatron_kwargs = kwargs.copy()
        megatron_config = lora_config.megatron_config
        if isinstance(megatron_config, dict):
            transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
            megatron_config = transformer_config_class(**lora_config.megatron_config)
        megatron_kwargs["megatron_config"] = megatron_config
        if megatron_kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
                "or `RowParallelLinear`. "
                "Setting fan_in_fan_out to False."
            )
            megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        new_module = LoraParallelLinear(
            base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
        )
    elif isinstance(target_base_layer, Conv1D):
        if not kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                "Setting fan_in_fan_out to True."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
        kwargs.update(lora_config.loftq_config)
        new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
    else:
        raise ValueError(
            f"Target module {target} is not supported. Currently, only the following modules are supported: "
            "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
        )

    return new_module

The reason is that this method contains the logic for deciding which LoRA layer backend to use for all the different types of LoRA layers that we have, e.g. the normal Linear layer, the Conv2d layer, the bnb layers, gptq, etc.

This PR greatly simplifies this method, reducing it to roughly 30 lines of code, which should make it easier to prevent bugs. It should also simplify adding further backends in the future.

Description

I moved the logic for deciding which layer to match into the respective layer implementations. For example, lora/layer.py now contains a function called dispatch_default, whose responsibility is to decide whether an Embedding, Conv2d, or Linear layer is the right match. Similarly, lora/bnb.py now contains the two functions dispatch_bnb_8bit and dispatch_bnb_4bit, which decide whether a bnb 8-bit or 4-bit layer should be matched. The same applies to the gptq and megatron backends.
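For illustration, a dispatcher for the default backend can be sketched roughly as follows (simplified signature and checks, not the exact code added in this PR; the fan_in_fan_out and loftq handling from the original method is omitted): it inspects the base layer and either returns a new LoRA module or None, in which case the caller falls through to the next dispatcher.

def dispatch_default(target, adapter_name, lora_config, **kwargs):
    # unwrap an existing tuner layer to get at the underlying module
    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    new_module = None
    if isinstance(target_base_layer, torch.nn.Embedding):
        new_module = Embedding(target, adapter_name, **kwargs)
    elif isinstance(target_base_layer, torch.nn.Conv2d):
        new_module = Conv2d(target, adapter_name, **kwargs)
    elif isinstance(target_base_layer, torch.nn.Linear):
        new_module = Linear(target, adapter_name, **kwargs)

    # None means "no match here"; the caller will try the next dispatcher
    return new_module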

This way, the logic to decide which layer to match now resides next to the respective layers. The only thing that LoraModel needs to do is collect all the dispatching functions and use the first one that returns a match.
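Put together, the dispatching side of LoraModel._create_new_module then boils down to something like the following sketch (the names of the gptq and megatron dispatchers are assumed for illustration, and error handling is abbreviated):

@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
    # order matters: quantized backends are checked before the default one
    dispatchers = [
        dispatch_bnb_8bit,
        dispatch_bnb_4bit,
        dispatch_gptq,      # assumed name for the gptq dispatcher
        dispatch_megatron,  # assumed name for the megatron dispatcher
        dispatch_default,
    ]

    new_module = None
    for dispatcher in dispatchers:
        new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
        if new_module is not None:  # first match wins
            break

    if new_module is None:
        raise ValueError(f"Target module {target} is not supported.")

    return new_module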

Note that the logic to decide if a layer matches is 100% the same, just moved to a different place. Therefore, there should be no difference in the LoRA model being created.

Only LoRA was modified because the other tuners don't have multiple backends, so this approach was not necessary for them. The only exception is IA³, which has a normal and a bnb backend. Since there are only two, it is not as complicated as for LoRA, but if this PR is accepted, I can refactor IA³ in a similar fashion.

Other changes

  • Removed the optional_kwargs argument from _create_and_replace, as it was an unnecessary indirection.
  • Removed the bias argument from kwargs, as it was not used.

Backwards compatibility

This should be fully backwards compatible, as the constructed LoRA model is 100% the same. If there are users that override _create_new_module, their code will probably break, but since this is a private method, we should be fine.

Edit: Also ran regression tests and they passed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada (Contributor) left a comment

This is much cleaner indeed, thanks a lot for putting effort into this refactor @BenjaminBossan and making the codebase cleaner!

@pacman100 (Contributor) left a comment

Thank you @BenjaminBossan, super clean refactor! 🧹

@BenjaminBossan merged commit 54ee2fb into huggingface:main Jan 9, 2024
14 checks passed
@BenjaminBossan deleted the refactor-dispatch-lora-layers branch January 9, 2024 11:18