Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and a couple of fixes for adapter layer updates #1268

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/peft/tuners/adalora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, base_layer: nn.Module) -> None:
self.ranknum = nn.ParameterDict({})

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
Expand Down
13 changes: 1 addition & 12 deletions src/peft/tuners/ia3/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> Non
self.out_features = out_features

def update_layer(self, adapter_name, init_ia3_weights):
# This code works for linear layers, override for other layer types
# Actual trainable parameters
if self.is_feedforward:
weight = torch.randn((1, self.in_features))
Expand Down Expand Up @@ -89,18 +90,6 @@ def __init__(
self._active_adapter = adapter_name
self.update_layer(adapter_name, init_ia3_weights)

def update_layer(self, adapter_name, init_ia3_weights):
# Actual trainable parameters
if self.is_feedforward:
weight = torch.randn((1, self.in_features))
else:
weight = torch.randn((self.out_features, 1))
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self.to(self.get_base_layer().weight.device)
self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Expand Down
7 changes: 1 addition & 6 deletions src/peft/tuners/ia3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,7 @@ def _create_and_replace(
"is_feedforward": is_feedforward,
}

if isinstance(target, Conv2d):
target.update_layer(
adapter_name,
ia3_config.init_ia3_weights,
)
elif isinstance(target, Linear):
if isinstance(target, IA3Layer):
target.update_layer(
adapter_name,
ia3_config.init_ia3_weights,
Expand Down
2 changes: 2 additions & 0 deletions src/peft/tuners/loha/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def update_layer(
use_effective_conv2d (`bool`, *optional*, defaults to `False`):
Use parameter effective decomposition for Conv2d with ksize > 1.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.alpha[adapter_name] = alpha
Expand Down
2 changes: 2 additions & 0 deletions src/peft/tuners/lokr/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def update_layer(
decompose_both (`bool`): Perform rank decomposition of left kronecker product matrix.
decompose_factor (`int`): Kronecker product decomposition factor.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.alpha[adapter_name] = alpha
Expand Down
6 changes: 3 additions & 3 deletions src/peft/tuners/lora/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def forward(self, x: torch.Tensor):
result += output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep
def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep

# TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
# def reset_lora_parameters(self, adapter_name):
Expand Down
172 changes: 88 additions & 84 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,106 +72,41 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.out_features = out_features

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
# This code works for linear layers, override for other layer types
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
if r > 0:
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex:
self.to(weight.device, dtype=weight.dtype)
else:
self.to(weight.device)
self.set_adapter(self.active_adapters)

def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout[adapter_name] = lora_dropout_layer
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
base_layer = self.get_base_layer()
if r > 0:
kernel_size = base_layer.kernel_size
stride = base_layer.stride
padding = base_layer.padding
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

weight = getattr(base_layer, "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
self.to(base_layer.weight.device, dtype=weight.dtype)
self.set_adapter(self.active_adapters)

def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout[adapter_name] = lora_dropout_layer
# Actual trainable parameters
if r > 0:
weight_A = torch.randn((r, self.in_features))
weight_B = torch.randn((self.out_features, r))
self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

base_layer = self.get_base_layer()
weight = getattr(base_layer, "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
self.to(base_layer.weight.device, dtype=weight.dtype)
# check weight and qweight (for GPTQ)
for weight_name in ("weight", "qweight"):
weight = getattr(self.get_base_layer(), weight_name, None)
if weight is not None:
# the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex:
self.to(weight.device, dtype=weight.dtype)
else:
self.to(weight.device)
break
self.set_adapter(self.active_adapters)

def reset_lora_parameters(self, adapter_name, init_lora_weights):
Expand Down Expand Up @@ -407,7 +342,41 @@ def __init__(
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout[adapter_name] = lora_dropout_layer
# Actual trainable parameters
weight_A = torch.randn((r, self.in_features))
weight_B = torch.randn((self.out_features, r))
self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

base_layer = self.get_base_layer()
weight = getattr(base_layer, "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
self.to(base_layer.weight.device, dtype=weight.dtype)
self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Expand Down Expand Up @@ -551,7 +520,42 @@ def __init__(
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout[adapter_name] = lora_dropout_layer
# Actual trainable parameters
base_layer = self.get_base_layer()
kernel_size = base_layer.kernel_size
stride = base_layer.stride
padding = base_layer.padding
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

weight = getattr(base_layer, "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
self.to(base_layer.weight.device, dtype=weight.dtype)
self.set_adapter(self.active_adapters)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Expand Down
33 changes: 4 additions & 29 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _create_and_replace(
):
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")

# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(f".*\.{key}$", current_key), pattern_keys), current_key)
Expand All @@ -164,36 +165,10 @@ def _create_and_replace(
if quantization_config is not None:
kwargs["gptq_quantization_config"] = quantization_config

linear_types = (Linear,)
if is_bnb_available():
from .bnb import Linear8bitLt

linear_types += (Linear8bitLt,)
if is_bnb_4bit_available():
from .bnb import Linear4bit
# note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
from peft.tuners.adalora import AdaLoraLayer

linear_types += (Linear4bit,)

# TODO: better deal with that
if isinstance(target, Conv2d):
target.update_layer_conv2d(
adapter_name,
r,
alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
lora_config.use_rslora,
)
elif isinstance(target, Embedding):
target.update_layer_embedding(
adapter_name,
r,
alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
lora_config.use_rslora,
)
elif isinstance(target, linear_types):
if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
target.update_layer(
adapter_name,
r,
Expand Down
2 changes: 2 additions & 0 deletions src/peft/tuners/oft/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def update_layer(
The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
block_share (`bool`): Whether to share the OFT parameters between blocks or not.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

self.r[adapter_name] = r
self.module_dropout[adapter_name] = module_dropout
Expand Down
16 changes: 16 additions & 0 deletions tests/test_common_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def test_lora_bnb_quantization_from_pretrained_safetensors(self, quantization):
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))

# check that both adapters are in the same layer
self.assertIn("default", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A)
self.assertIn("adapter2", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A)

@require_bitsandbytes
@pytest.mark.multi_gpu_tests
@pytest.mark.single_gpu_tests
Expand Down Expand Up @@ -244,6 +248,10 @@ def test_adalora_bnb_quantization_from_pretrained_safetensors(self, quantization
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))

# check that both adapters are in the same layer
self.assertIn("default", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A)
self.assertIn("adapter2", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A)

@require_bitsandbytes
@pytest.mark.multi_gpu_tests
@pytest.mark.single_gpu_tests
Expand Down Expand Up @@ -277,6 +285,10 @@ def test_ia3_bnb_quantization_from_pretrained_safetensors(self, quantization):
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))

# check that both adapters are in the same layer
self.assertIn("default", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.ia3_l)
self.assertIn("adapter2", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.ia3_l)

@pytest.mark.single_gpu_tests
def test_lora_gptq_quantization_from_pretrained_safetensors(self):
r"""
Expand Down Expand Up @@ -311,6 +323,10 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self):
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))

# check that both adapters are in the same layer
self.assertIn("default", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A)
self.assertIn("adapter2", model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A)

@require_bitsandbytes
@pytest.mark.multi_gpu_tests
@pytest.mark.single_gpu_tests
Expand Down
Loading