Add LoRA to ConvND layers (#19516)
* Add LoRA to `BaseConv`

* Add tests

* Fix typo

* Fix tests

* Fix tests
james77777778 authored Apr 15, 2024
1 parent 4c67dcf commit 38a0caa
Showing 2 changed files with 340 additions and 9 deletions.
130 changes: 121 additions & 9 deletions keras/layers/convolutional/base_conv.py
@@ -71,6 +71,15 @@ class BaseConv(Layer):
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
lora_rank: Optional integer. If set, the layer's forward pass
will implement LoRA (Low-Rank Adaptation)
with the provided rank. LoRA sets the layer's kernel
to non-trainable and replaces it with a delta over the
original kernel, obtained via multiplying two lower-rank
trainable matrices. This can be useful to reduce the
computation cost of fine-tuning large dense layers.
You can also enable LoRA on an existing layer by calling
`layer.enable_lora(rank)`.
"""

def __init__(
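A minimal usage sketch of the option being added (not part of the diff; assumes Keras 3 with this change applied, and the layer sizes and input shapes below are arbitrary):

import numpy as np
import keras

# Enable LoRA at construction time via `lora_rank` ...
layer = keras.layers.Conv2D(filters=64, kernel_size=3, padding="same", lora_rank=4)

# ... or afterwards, on an already-built layer, via `enable_lora(rank)`.
other = keras.layers.Conv2D(filters=64, kernel_size=3, padding="same")
other.build((None, 32, 32, 16))   # kernel shape: (3, 3, 16, 64)
other.enable_lora(4)              # freezes the base kernel, adds lora_kernel_a/b

x = np.random.random((2, 32, 32, 16)).astype("float32")
y = other(x)                      # forward pass uses kernel + lora_kernel_a @ lora_kernel_b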
@@ -92,16 +101,10 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
-        trainable=True,
-        name=None,
+        lora_rank=None,
**kwargs,
):
-        super().__init__(
-            trainable=trainable,
-            name=name,
-            activity_regularizer=activity_regularizer,
-            **kwargs,
-        )
+        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
self.rank = rank
self.filters = filters
self.groups = groups
@@ -120,6 +123,8 @@ def __init__(
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.lora_rank = lora_rank
self.lora_enabled = False
self.input_spec = InputSpec(min_ndim=self.rank + 2)
self.data_format = self.data_format

@@ -187,7 +192,7 @@ def build(self, input_shape):
# shape, and make sure the output shape has all positive dimensions.
self.compute_output_shape(input_shape)

-        self.kernel = self.add_weight(
+        self._kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
@@ -209,6 +214,20 @@ def build(self, input_shape):
else:
self.bias = None
self.built = True
if self.lora_rank:
self.enable_lora(self.lora_rank)

@property
def kernel(self):
if not self.built:
raise AttributeError(
"You must build the layer before accessing `kernel`."
)
if self.lora_enabled:
return self._kernel + ops.matmul(
self.lora_kernel_a, self.lora_kernel_b
)
return self._kernel

def convolution_op(self, inputs, kernel):
return ops.conv(
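The `kernel` property above is where the adaptation happens: once LoRA is enabled, the effective kernel is the frozen base kernel plus the low-rank product lora_kernel_a @ lora_kernel_b. A rough sketch of the shapes and the parameter savings (illustration only; a 3x3 kernel with 256 input channels, 512 filters, and rank 4 is assumed):

import numpy as np

kernel = np.zeros((3, 3, 256, 512))   # frozen base kernel: 1,179,648 parameters
lora_a = np.zeros((3, 3, 256, 4))     # shape = kernel.shape[:-1] + (rank,)
lora_b = np.zeros((4, 512))           # shape = (rank, filters)

# matmul broadcasts over the leading spatial axes, so the low-rank delta
# has exactly the same shape as the base kernel.
effective = kernel + lora_a @ lora_b
assert effective.shape == kernel.shape

# Trainable parameters drop from 3*3*256*512 = 1,179,648
# to 3*3*256*4 + 4*512 = 11,264.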
@@ -248,6 +267,63 @@ def compute_output_shape(self, input_shape):
dilation_rate=self.dilation_rate,
)

def enable_lora(
self, rank, a_initializer="he_uniform", b_initializer="zeros"
):
if self.kernel_constraint:
raise ValueError(
"Lora is incompatible with kernel constraints. "
"In order to enable lora on this layer, remove the "
"`kernel_constraint` argument."
)
if not self.built:
raise ValueError(
"Cannot enable lora on a layer that isn't yet built."
)
if self.lora_enabled:
raise ValueError(
"lora is already enabled. "
"This can only be done once per layer."
)
self._tracker.unlock()
self.lora_kernel_a = self.add_weight(
name="lora_kernel_a",
shape=self._kernel.shape[:-1] + (rank,),
initializer=initializers.get(a_initializer),
regularizer=self.kernel_regularizer,
)
self.lora_kernel_b = self.add_weight(
name="lora_kernel_b",
shape=(rank, self.filters),
initializer=initializers.get(b_initializer),
regularizer=self.kernel_regularizer,
)
self._kernel.trainable = False
self._tracker.lock()
self.lora_enabled = True
self.lora_rank = rank

def save_own_variables(self, store):
# Do nothing if the layer isn't yet built
if not self.built:
return
store["0"] = self.kernel
if self.use_bias:
store["1"] = self.bias

def load_own_variables(self, store):
if not self.lora_enabled:
self._check_load_own_variables(store)
# Do nothing if the layer isn't yet built
if not self.built:
return
self._kernel.assign(store["0"])
if self.use_bias:
self.bias.assign(store["1"])
if self.lora_enabled:
self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))

def get_config(self):
config = super().get_config()
config.update(
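One consequence of the `save_own_variables` / `load_own_variables` pair above: what gets saved is the merged kernel (base plus LoRA delta), and loading writes it back into the frozen base kernel while the LoRA factors are reset to zeros, so the effective weights round-trip exactly. A hedged sketch of that contract (layer sizes are arbitrary; assumes Keras 3 with this change):

import numpy as np
import keras

layer = keras.layers.Conv2D(8, 3, lora_rank=2)
layer.build((None, 16, 16, 4))
# Pretend some fine-tuning happened so the LoRA delta is non-zero.
layer.lora_kernel_b.assign(keras.ops.ones(layer.lora_kernel_b.shape))

store = {}
layer.save_own_variables(store)    # store["0"] is the merged kernel, store["1"] the bias
merged = keras.ops.convert_to_numpy(layer.kernel)

restored = keras.layers.Conv2D(8, 3, lora_rank=2)
restored.build((None, 16, 16, 4))
restored.load_own_variables(store)  # base kernel <- merged kernel, A/B <- zeros

np.testing.assert_allclose(keras.ops.convert_to_numpy(restored.kernel), merged)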
@@ -282,4 +358,40 @@ def get_config(self):
"bias_constraint": constraints.serialize(self.bias_constraint),
}
)
if self.lora_rank:
config["lora_rank"] = self.lora_rank
return config

def _check_load_own_variables(self, store):
all_vars = self._trainable_variables + self._non_trainable_variables
if len(store.keys()) != len(all_vars):
if len(all_vars) == 0 and not self.built:
raise ValueError(
f"Layer '{self.name}' was never built "
"and thus it doesn't have any variables. "
f"However the weights file lists {len(store.keys())} "
"variables for this layer.\n"
"In most cases, this error indicates that either:\n\n"
"1. The layer is owned by a parent layer that "
"implements a `build()` method, but calling the "
"parent's `build()` method did NOT create the state of "
f"the child layer '{self.name}'. A `build()` method "
"must create ALL state for the layer, including "
"the state of any children layers.\n\n"
"2. You need to implement "
"the `def build_from_config(self, config)` method "
f"on layer '{self.name}', to specify how to rebuild "
"it during loading. "
"In this case, you might also want to implement the "
"method that generates the build config at saving time, "
"`def get_build_config(self)`. "
"The method `build_from_config()` is meant "
"to create the state "
"of the layer (i.e. its variables) upon deserialization.",
)
raise ValueError(
f"Layer '{self.name}' expected {len(all_vars)} variables, "
"but received "
f"{len(store.keys())} variables during loading. "
f"Expected: {[v.name for v in all_vars]}"
)
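Finally, the practical effect during fine-tuning: after `enable_lora()` only the two low-rank factors (plus the bias, if any) remain trainable, while the full kernel is frozen. A quick, illustrative way to check this:

import keras

layer = keras.layers.Conv2D(32, 3)
layer.build((None, 28, 28, 8))
layer.enable_lora(4)

print(len(layer.trainable_weights))      # 3: bias, lora_kernel_a, lora_kernel_b
print(len(layer.non_trainable_weights))  # 1: the frozen base kernel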