Megatron distributed parallel linear LoRA (#1092)
Adds an option to use Megatron's ColumnParallelLinear and RowParallelLinear
for LoRA linear layers, leading to improved performance when using LoRA
with Megatron.
zhangsheng377 authored Nov 30, 2023
1 parent 2b901ee commit 2674f5e
Showing 5 changed files with 381 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -141,6 +141,32 @@ class LoraConfig(PeftConfig):
            )
        },
    )
    megatron_config: Optional[dict] = field(
        default=None,
        metadata={
            "help": (
                "The TransformerConfig from Megatron. It is used to create LoRA's parallel linear layers. "
                "You can obtain it with `core_transformer_config_from_args(get_args())`; both functions come "
                "from Megatron. "
                "You need to specify this parameter when you want to apply LoRA to Megatron's "
                "ColumnParallelLinear and RowParallelLinear layers. "
                "Note that `save_pretrained` and `from_pretrained` may not work with it, because "
                "TransformerConfig is not necessarily serializable. "
                "When using Megatron, you can instead rely on the `get_peft_model_state_dict` function and "
                "Megatron's own framework to save and load models and configurations."
            )
        },
    )
    megatron_core: Optional[str] = field(
        default="megatron.core",
        metadata={
            "help": (
                "The core module from Megatron. It is used to detect and create LoRA's parallel linear layers. "
                "You only need to pass it when you use your own modified Megatron core module; otherwise the "
                "default value `megatron.core` is used."
            )
        },
    )
    # dict type is used when loading config.json
    loftq_config: Union[LoftQConfig, dict] = field(
        default_factory=dict,
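For reference, a minimal sketch of how these two new fields can be wired together. The Megatron import paths, the target module names, and the `model` variable are assumptions that vary with the Megatron-LM setup and are not part of this change:

# Build Megatron's TransformerConfig and hand it to LoraConfig so that ColumnParallelLinear /
# RowParallelLinear targets are wrapped by the new parallel LoRA layer.
from megatron import get_args  # import paths may differ across Megatron-LM versions
from megatron.arguments import core_transformer_config_from_args

from peft import LoraConfig, get_peft_model

transformer_config = core_transformer_config_from_args(get_args())

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query_key_value", "dense"],  # example Megatron module names; adjust to your model
    megatron_config=transformer_config,  # used to construct the parallel LoRA layers
    megatron_core="megatron.core",  # default; point this at a modified core module if you use one
)

model = get_peft_model(model, lora_config)  # `model` is your Megatron model, assumed to exist already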
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -62,6 +62,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
            # QuantLinear
            in_features, out_features = base_layer.infeatures, base_layer.outfeatures
        elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
            # Megatron ColumnParallelLinear, RowParallelLinear
            in_features, out_features = base_layer.input_size, base_layer.output_size
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

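A quick standalone illustration of the duck-typed dimension lookup added above; the stand-in class is hypothetical, but a real Megatron ColumnParallelLinear or RowParallelLinear exposes the same `input_size`/`output_size` attributes:

class FakeParallelLinear:
    """Stand-in exposing the attributes the new branch checks for."""

    def __init__(self, input_size: int, output_size: int):
        self.input_size = input_size
        self.output_size = output_size


base_layer = FakeParallelLinear(input_size=1024, output_size=4096)
if hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
    in_features, out_features = base_layer.input_size, base_layer.output_size
print(in_features, out_features)  # 1024 4096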
27 changes: 27 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import operator
import re
@@ -259,6 +260,10 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
        else:
            target_base_layer = target

        megatron_core = None
        if lora_config.megatron_config:
            megatron_core = importlib.import_module(lora_config.megatron_core)

        if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
            eightbit_kwargs = kwargs.copy()
            eightbit_kwargs.update(
@@ -300,6 +305,28 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
kwargs.update(lora_config.loftq_config)
new_module = Linear(target, adapter_name, **kwargs)
elif megatron_core and isinstance(
target_base_layer,
(megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
):
from .tp_layer import LoraParallelLinear

megatron_kwargs = kwargs.copy()
megatron_config = lora_config.megatron_config
if isinstance(megatron_config, dict):
transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
megatron_config = transformer_config_class(**lora_config.megatron_config)
megatron_kwargs["megatron_config"] = megatron_config
if megatron_kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
"or `RowParallelLinear`. "
"Setting fan_in_fan_out to False."
)
megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
new_module = LoraParallelLinear(
base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
)
elif isinstance(target_base_layer, Conv1D):
if not kwargs["fan_in_fan_out"]:
warnings.warn(
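A brief sketch of the two Megatron-specific steps in the dispatch above: importing the core module named by `megatron_core`, and rebuilding `megatron_config` into a TransformerConfig when it arrives as a plain dict (for example after loading `config.json`). The helper name and the field values are illustrative assumptions:

import importlib

megatron_core = importlib.import_module("megatron.core")  # value of LoraConfig.megatron_core, default shown
TransformerConfig = megatron_core.transformer.transformer_config.TransformerConfig


def as_transformer_config(megatron_config):
    # Hypothetical helper mirroring the dict -> TransformerConfig rebuild in _create_new_module.
    if isinstance(megatron_config, dict):
        megatron_config = TransformerConfig(**megatron_config)
    return megatron_config


# e.g. after a round trip through config.json (illustrative values only):
cfg = as_transformer_config({"num_layers": 24, "hidden_size": 1024, "num_attention_heads": 16})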
158 changes: 158 additions & 0 deletions src/peft/tuners/lora/tp_layer.py
@@ -0,0 +1,158 @@
from typing import Any

import torch
import torch.nn as nn
import torch.nn.init as init

from .layer import LoraLayer


class LoraParallelLinear(nn.Module, LoraLayer):
"""
When the target layer parallel_linear is RowParallelLinear, in order to keep the input and output shapes
consistent, we need to split the lora matrix A into rows, and the lora_B at this time should be a complete linear
layer; In the same way, when the target layer is ColumnParallelLinear, we perform column segmentation on lora_B,
while lora_A is still a complete linear layer.
"""

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        backend,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer=base_layer)

        self.backend = backend
        self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name

        megatron_config = kwargs["megatron_config"]
        parallel_linear_kwargs = {"megatron_config": megatron_config}
        init_method = init.xavier_normal_
        if hasattr(megatron_config, "init_method"):
            init_method = megatron_config.init_method
        input_is_parallel = True
        gather_output = False
        if isinstance(base_layer, self.backend.RowParallelLinear):
            input_is_parallel = base_layer.input_is_parallel
        else:
            gather_output = base_layer.gather_output
        self.update_layer(
            adapter_name,
            r,
            lora_alpha,
            lora_dropout,
            init_lora_weights,
            init_method,
            input_is_parallel,
            gather_output,
            **parallel_linear_kwargs,
        )

        self.is_target_conv_1d_layer = False

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        init_method=init.xavier_normal_,
        input_is_parallel=True,
        gather_output=False,
        **parallel_linear_kwargs,
    ):
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout[adapter_name] = lora_dropout_layer

        megatron_config = parallel_linear_kwargs["megatron_config"]
        # lora needs to be forced to 32-bit precision, otherwise it will overflow
        megatron_config.params_dtype = torch.float32
        if self.is_parallel_a:
            lora_a = self.backend.RowParallelLinear(
                input_size=self.in_features,
                output_size=r,
                bias=False,
                input_is_parallel=input_is_parallel,
                skip_bias_add=True,
                init_method=init_method,
                config=megatron_config,
            )
            lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
        else:
            lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
            lora_b = self.backend.ColumnParallelLinear(
                input_size=r,
                output_size=self.out_features,
                bias=False,
                gather_output=gather_output,
                init_method=init_method,
                config=megatron_config,
            )
        self.lora_A[adapter_name] = lora_a
        self.lora_B[adapter_name] = lora_b
        self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)

        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        previous_dtype = x.dtype
        # If the weight were used for the matrix multiplication here, the final aggregation operation of the
        # original parallel_linear layer would be missing, so we directly call its forward function to obtain
        # the output of the original parallel_linear layer.
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result, bias = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result, bias = self.base_layer(x, *args, **kwargs)
        else:
            result, bias = self.base_layer(x, *args, **kwargs)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                x = x.to(lora_A.weight.dtype)

                lora_result = lora_A(dropout(x))
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_B(lora_result)
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_result * scaling

                result = result + lora_result

        result = result.to(previous_dtype)
        return result, bias
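To make the adapter math in `forward` concrete, here is a toy, non-parallel rendering of the update it computes, `result = base_layer(x) + lora_B(lora_A(dropout(x))) * scaling`; plain nn.Linear modules stand in for the Megatron parallel layers and all sizes are made up:

import torch
import torch.nn as nn

in_features, out_features, r, lora_alpha = 16, 32, 4, 8
base_layer = nn.Linear(in_features, out_features, bias=False)  # stand-in for the parallel linear layer
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
nn.init.zeros_(lora_B.weight)  # as in LoRA, B starts at zero so the adapter is initially a no-op
dropout = nn.Identity()
scaling = lora_alpha / r

x = torch.randn(2, in_features)
result = base_layer(x) + lora_B(lora_A(dropout(x))) * scaling
print(result.shape)  # torch.Size([2, 32])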