
Commit 0503fe0

Fix row parallel lora layers parameters initialization bug (PaddlePaddle#9425)

* Fix row parallel lora layers parameters initialization bug

* Fix ColumnParallel LoRA layers
will-jl944 authored Nov 16, 2024
1 parent 5f4dd96 commit 0503fe0
Showing 1 changed file with 44 additions and 34 deletions.
78 changes: 44 additions & 34 deletions paddlenlp/peft/lora/lora_layers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from contextlib import nullcontext
 from typing import Optional
 
 import paddle
@@ -22,6 +23,7 @@
 from paddle.distributed.fleet.meta_parallel import (
     ColumnParallelLinear,
     RowParallelLinear,
+    get_rng_state_tracker,
 )
 
 from ...transformers import linear_utils
@@ -50,6 +52,10 @@
 from .lora_quick_layers import quick_lora
 
 
+def rng_ctx(is_mp: bool, in_dynamic_mode: bool):
+    return get_rng_state_tracker().rng_state() if (is_mp and in_dynamic_mode) else nullcontext()
+
+
 class LoRALinear(nn.Linear):
     # LoRA implemented in a dense layer
     def __init__(
@@ -198,14 +204,15 @@ def __init__(
         self.name = self._name
 
         # Actual trainable parameters
-        self.lora_A = self.create_parameter(
-            shape=[self.input_size_per_partition, r],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_A = self.create_parameter(
+                shape=[self.input_size_per_partition, r],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+                ),
+            )
         self.lora_B = self.create_parameter(
             shape=[r, self.out_features],
             dtype=self._dtype,
@@ -345,14 +352,15 @@ def __init__(
         self.name = self._name
 
         # Actual trainable parameters
-        self.lora_A = self.create_parameter(
-            shape=[self.input_size_per_partition, r],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_A = self.create_parameter(
+                shape=[self.input_size_per_partition, r],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+                ),
+            )
         self.lora_B = self.create_parameter(
             shape=[r, self.out_features],
             dtype=self._dtype,
@@ -468,15 +476,16 @@ def __init__(
             attr=lora_A_weight_attr,
         )
         self.lora_A.is_distributed = False
-        self.lora_B = self.create_parameter(
-            shape=[r, self.output_size_per_partition],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(value=0.0),
-                learning_rate=lora_plus_scale,
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_B = self.create_parameter(
+                shape=[r, self.output_size_per_partition],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=paddle.nn.initializer.Constant(value=0.0),
+                    learning_rate=lora_plus_scale,
+                ),
+            )
 
         self.lora_B.is_distributed = True
         self.lora_B.split_axis = 1
@@ -599,15 +608,16 @@ def __init__(
         self.lora_A.is_distributed = False
         mark_as_sequence_parallel_parameter(self.lora_A)
 
-        self.lora_B = self.create_parameter(
-            shape=[r, self.output_size_per_partition],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(value=0.0),
-                learning_rate=lora_plus_scale,
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_B = self.create_parameter(
+                shape=[r, self.output_size_per_partition],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=paddle.nn.initializer.Constant(value=0.0),
+                    learning_rate=lora_plus_scale,
+                ),
+            )
 
         self.lora_B.is_distributed = True
         self.lora_B.split_axis = 1
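For context, the core of this change is the new rng_ctx helper: when model parallelism is active, the sharded LoRA parameter in each layer is now created under the tensor-parallel RNG state tracker instead of the global seed. Below is a minimal, hypothetical sketch (not taken from the repository) of the same pattern; ToyShardedLoRALinear, in_features_per_partition, and out_features are illustrative names, and it assumes a Paddle environment where fleet model parallelism is already initialized whenever is_mp is True.

    from contextlib import nullcontext

    import paddle
    from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker


    def rng_ctx(is_mp: bool, in_dynamic_mode: bool):
        # Same helper as in the commit: switch to the tracked, per-rank RNG state
        # only when model parallelism is on and we are in dynamic-graph mode.
        return get_rng_state_tracker().rng_state() if (is_mp and in_dynamic_mode) else nullcontext()


    class ToyShardedLoRALinear(paddle.nn.Layer):
        # Hypothetical layer for illustration only; not part of PaddleNLP.
        def __init__(self, in_features_per_partition: int, out_features: int, r: int, is_mp: bool = False):
            super().__init__()
            self.is_mp = is_mp
            # lora_A is sharded along the input dimension, so each tensor-parallel
            # rank should draw its own random values; created under the tracker,
            # the shards no longer all come from the same global RNG stream.
            with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
                self.lora_A = self.create_parameter(
                    shape=[in_features_per_partition, r],
                    dtype=paddle.get_default_dtype(),
                    is_bias=False,
                )
            # lora_B is replicated in this toy layer, so the global seed is fine;
            # it is zero-initialized as in the LoRA formulation.
            self.lora_B = self.create_parameter(
                shape=[r, out_features],
                dtype=paddle.get_default_dtype(),
                is_bias=False,
                default_initializer=paddle.nn.initializer.Constant(value=0.0),
            )

Which matrix needs the tracked state depends on how the layer is sharded: in the commit, the row-parallel variants wrap lora_A (split along the input dimension), while the column-parallel variants wrap lora_B (split along the output dimension).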
