address comments

mattdangerw committed Oct 9, 2023
1 parent dbcb02f commit 3a8649b
Showing 2 changed files with 20 additions and 7 deletions.
14 changes: 7 additions & 7 deletions keras_nlp/layers/modeling/lora_dense.py
@@ -40,11 +40,11 @@ class LoraDense(keras.layers.Layer):
     """A LoRA adapter layer for a dense input layer.
 
     This layer implements a low-rank decomposition of a dense transformation, as
-    described in [LoRA: Low-Rand Adaptation Of Large Language Models](https://arxiv.org/pdf/2106.09685.pdf)
-    This layer can be used to replace a dense layer with a layer who's
+    described in [LoRA: Low-Rank Adaptation Of Large Language Models](https://arxiv.org/pdf/2106.09685.pdf)
+    This layer can be used to replace a dense layer with a layer whose
     parameters are mostly frozen.
 
-    By default, this layer takes in a `inner_dense` layer, freezes it's
+    By default, this layer takes in an `inner_dense` layer, freezes its
     parameters, and builds a low-rank decomposed update to sum with the original
     `inner_dense` output. These update parameters can be merged back into the
     `inner_dense` kernel by calling `merge_weights()`.
@@ -56,8 +56,8 @@ class LoraDense(keras.layers.Layer):
             represent a dense transformation on the last axis of the input,
             though adding new axes to the output (e.g. a multi-head axis) is
             allowed.
-        rank: int The inner rank of the decomposed dense transformation. The
-            lower this number, the less trainable parameters the layer will
+        rank: int. The inner rank of the decomposed dense transformation. The
+            lower this number, the fewer trainable parameters the layer will
             have.
         alpha: float. A constant value used for scaling the lora update. The
             lora update to the original dense transformation will be scaled by
@@ -66,7 +66,7 @@ class LoraDense(keras.layers.Layer):
             from layer inputs to the inner `rank` intermediate outputs.
         freeze_kernel: If true, the kernel of the inner dense layer will have
             `trainable` set to False.
-        freeze_bais: If true, the kernel of the inner dense layer will have
+        freeze_bias: If true, the kernel of the inner dense layer will have
             `trainable` set to False.
         **kwargs: other keyword arguments.
@@ -112,7 +112,7 @@ def __init__(
         self,
         inner_dense,
         rank=8,
-        alpha=32.0,
+        alpha=8.0,
         lora_a_initializer="variance_scaling",
         freeze_kernel=True,
         freeze_bias=True,
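For orientation, here is a minimal usage sketch of the adapter layer edited above, pieced together from its docstring and the test file below. The `from keras_nlp.layers import LoraDense` import path and the dummy input are assumptions for illustration; only the constructor arguments and `merge_weights()` appear in this diff.

```python
import numpy as np
import keras

from keras_nlp.layers import LoraDense  # assumed export path for the layer above

# Wrap an existing dense layer. By default its kernel and bias are frozen and
# only the low-rank update parameters stay trainable.
inner_dense = keras.layers.Dense(16)
lora_dense = LoraDense(inner_dense, rank=8, alpha=8.0)
lora_dense.build((2, 16))

# The adapter output is the frozen dense output plus a scaled low-rank update.
outputs = lora_dense(np.random.uniform(size=(2, 16)).astype("float32"))

# After fine-tuning, the update can be folded back into the original kernel.
lora_dense.merge_weights()
```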
13 changes: 13 additions & 0 deletions keras_nlp/layers/modeling/lora_dense_test.py
@@ -96,6 +96,19 @@ def test_merge_einsum(self):
         self.assertAllClose(lora_output, merged_lora_output)
         self.assertAllClose(lora_output, dense_output)
 
+    def test_freezing(self):
+        inner_dense = keras.layers.Dense(16)
+        layer = LoraDense(inner_dense, freeze_bias=False)
+        layer.build((2, 16))
+        self.assertFalse(inner_dense.kernel.trainable)
+        self.assertTrue(inner_dense.bias.trainable)
+
+        inner_dense = keras.layers.Dense(16)
+        layer = LoraDense(inner_dense)
+        layer.build((2, 16))
+        self.assertFalse(inner_dense.kernel.trainable)
+        self.assertFalse(inner_dense.bias.trainable)
+
     def test_errors_if_not_dense(self):
         with self.assertRaises(ValueError):
             LoraDense(keras.layers.Concatenate())
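To make the docstring above concrete, a rough numpy sketch of the computation: the frozen dense output is summed with a low-rank update routed through the inner `rank` dimension, and merging folds that update back into the kernel. The `alpha / rank` scaling and the kernel orientation here are assumptions following the standard LoRA formulation, since that part of the docstring is collapsed in the diff.

```python
import numpy as np

def lora_dense_forward(x, kernel, bias, lora_a, lora_b, alpha=8.0, rank=8):
    """Illustrative forward pass: frozen dense output plus scaled low-rank update."""
    frozen = x @ kernel + bias               # original (frozen) dense transformation
    update = (x @ lora_a) @ lora_b           # bottleneck: (in, rank) then (rank, out)
    return frozen + (alpha / rank) * update  # assumed standard LoRA scaling

def merge_weights(kernel, lora_a, lora_b, alpha=8.0, rank=8):
    """Fold the low-rank update back into the dense kernel."""
    return kernel + (alpha / rank) * (lora_a @ lora_b)
```

With the new defaults in this commit (`alpha=8.0`, `rank=8`), the scale works out to 1.0, so the update is added to the frozen output unscaled.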
