Add config key for dropout position in LoRA adapter (#8583)
Signed-off-by: Michal Futrega <mfutrega@nvidia.com>
Signed-off-by: Pablo Garay <pagaray@nvidia.com>
michal2409 authored and pablo-garay committed Mar 19, 2024
1 parent 6eed58b commit 2fdc14d
Showing 2 changed files with 7 additions and 1 deletion.
@@ -145,6 +145,7 @@ def __init__(
         dropout: float = 0.0,
         model_parallel_config: Optional[ModelParallelConfig] = None,
         alpha: float | None = None,
+        dropout_position: str = 'post',
         **kwargs,
     ):
         super().__init__()
@@ -159,6 +160,7 @@ def __init__(
         self.dim = dim
         self.alpha = alpha if alpha is not None else self.dim
         self.input_is_parallel = input_is_parallel
+        self.dropout_position = dropout_position

         # megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
         # in case this arg is not provided, use the dummy default config.
@@ -256,6 +258,8 @@ def adapter_unfreeze(self,):
         super().adapter_unfreeze()

     def forward(self, x):
+        if self.dropout is not None and self.dropout_position == 'pre':
+            x = self.dropout(x)

         if self.norm_position == 'pre':
             x = self.layer_norm(x)
@@ -281,7 +285,7 @@ def forward(self, x):
             x = self.layer_norm(x)

         # Add dropout if available
-        if self.dropout is not None:
+        if self.dropout is not None and self.dropout_position == 'post':
             x = self.dropout(x)

         x = x * (self.alpha / self.dim)
@@ -302,6 +306,7 @@ class ParallelLinearAdapterConfig(AdapterConfig):
     gather_output: bool = True
     input_is_parallel: bool = False
     dropout: float = 0.0
+    dropout_position: str = 'post'
     alpha: float | None = None
     network_alpha: int | None = None
     _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)
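For readers skimming the diff, here is a minimal, self-contained sketch of what the new dropout_position switch does: 'pre' applies dropout to the adapter input before the down/up projections, while the default 'post' keeps the previous behaviour of applying it to the adapter output. This is not the actual ParallelLinearAdapter (which also handles layer norm, activation, and tensor-parallel linear layers); the TinyLoRAAdapter name and its arguments are hypothetical.

import torch
import torch.nn as nn

class TinyLoRAAdapter(nn.Module):
    """Simplified stand-in for the adapter: a down-projection, an up-projection,
    and a dropout whose placement is controlled by `dropout_position`."""

    def __init__(self, in_features, dim, dropout=0.1, alpha=None, dropout_position='post'):
        super().__init__()
        self.linear_in = nn.Linear(in_features, dim, bias=False)
        self.linear_out = nn.Linear(dim, in_features, bias=False)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
        self.dim = dim
        self.alpha = alpha if alpha is not None else dim
        self.dropout_position = dropout_position

    def forward(self, x):
        # 'pre': dropout is applied to the adapter input, before the projections.
        if self.dropout is not None and self.dropout_position == 'pre':
            x = self.dropout(x)

        x = self.linear_out(self.linear_in(x))

        # 'post': dropout is applied to the adapter output (the previous behaviour).
        if self.dropout is not None and self.dropout_position == 'post':
            x = self.dropout(x)

        return x * (self.alpha / self.dim)

adapter = TinyLoRAAdapter(in_features=16, dim=4, dropout=0.1, dropout_position='pre')
print(adapter(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
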
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/peft_config.py
@@ -183,6 +183,7 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
             "gather_output": False,
             "dropout": lora_cfg.adapter_dropout,
             "alpha": lora_cfg.get("alpha", lora_cfg.adapter_dim),
+            "dropout_position": lora_cfg.get("dropout_position", "post"),
         }

         if lora_cfg.weight_tying:
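The new key is read from the LoRA config with a 'post' default, so existing configs keep their current behaviour unless they opt in. A small sketch of how that lookup resolves, assuming an OmegaConf-style LoRA section; the key names other than dropout_position simply mirror the ones referenced in the hunk above, and the exact path of this section inside a full NeMo YAML is not shown here.

from omegaconf import OmegaConf

# Hypothetical minimal LoRA section; only `dropout_position` is new,
# the other keys mirror those read in _create_lora_config above.
lora_cfg = OmegaConf.create(
    {
        "adapter_dim": 32,
        "adapter_dropout": 0.1,
        "dropout_position": "pre",  # omit this key to keep the default 'post'
    }
)

# Mirrors the lookup in peft_config.py: a missing key falls back to 'post'.
print(lora_cfg.get("dropout_position", "post"))  # -> pre
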
