Add LoRA to ConvND layers #19516

Merged: 6 commits, Apr 15, 2024
130 changes: 121 additions & 9 deletions keras/layers/convolutional/base_conv.py
@@ -71,6 +71,15 @@ class BaseConv(Layer):
are not safe to use when doing asynchronous distributed training.
bias_constraint: Optional projection function to be applied to the
bias after being updated by an `Optimizer`.
lora_rank: Optional integer. If set, the layer's forward pass
will implement LoRA (Low-Rank Adaptation)
with the provided rank. LoRA sets the layer's kernel
to non-trainable and replaces it with a delta over the
original kernel, obtained via multiplying two lower-rank
trainable matrices. This can be useful to reduce the
computation cost of fine-tuning large convolution layers.
You can also enable LoRA on an existing layer by calling
`layer.enable_lora(rank)`.
"""

def __init__(
@@ -92,16 +101,10 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
lora_rank=None,
**kwargs,
):
super().__init__(
trainable=trainable,
name=name,
activity_regularizer=activity_regularizer,
**kwargs,
)
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
self.rank = rank
self.filters = filters
self.groups = groups
@@ -120,6 +123,8 @@ def __init__(
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.lora_rank = lora_rank
self.lora_enabled = False
self.input_spec = InputSpec(min_ndim=self.rank + 2)
self.data_format = self.data_format

@@ -187,7 +192,7 @@ def build(self, input_shape):
# shape, and make sure the output shape has all positive dimensions.
self.compute_output_shape(input_shape)

self.kernel = self.add_weight(
self._kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
@@ -209,6 +214,20 @@ def build(self, input_shape):
else:
self.bias = None
self.built = True
if self.lora_rank:
self.enable_lora(self.lora_rank)

@property
def kernel(self):
if not self.built:
raise AttributeError(
"You must build the layer before accessing `kernel`."
)
if self.lora_enabled:
return self._kernel + ops.matmul(
self.lora_kernel_a, self.lora_kernel_b
)
return self._kernel
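For intuition, the effective kernel returned by this property is just the frozen base kernel plus a low-rank delta. A standalone NumPy sketch with arbitrarily chosen shapes (3x3 kernel, 3 input channels, 16 filters, rank 4):

```python
import numpy as np

# Standalone NumPy sketch of the delta computed by the `kernel` property above.
rank = 4
base_kernel = np.random.randn(3, 3, 3, 16)   # frozen `_kernel`
lora_a = np.random.randn(3, 3, 3, rank)      # `lora_kernel_a`: kernel.shape[:-1] + (rank,)
lora_b = np.random.randn(rank, 16)           # `lora_kernel_b`: (rank, filters)

# matmul broadcasts over the leading kernel axes, so the delta matches the
# base kernel's shape while only the two small factors are trainable.
effective = base_kernel + lora_a @ lora_b
assert effective.shape == base_kernel.shape
```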

def convolution_op(self, inputs, kernel):
return ops.conv(
@@ -248,6 +267,63 @@ def compute_output_shape(self, input_shape):
dilation_rate=self.dilation_rate,
)

def enable_lora(
self, rank, a_initializer="he_uniform", b_initializer="zeros"
):
if self.kernel_constraint:
raise ValueError(
"Lora is incompatible with kernel constraints. "
"In order to enable lora on this layer, remove the "
"`kernel_constraint` argument."
)
if not self.built:
raise ValueError(
"Cannot enable lora on a layer that isn't yet built."
)
if self.lora_enabled:
raise ValueError(
"lora is already enabled. "
"This can only be done once per layer."
)
self._tracker.unlock()
self.lora_kernel_a = self.add_weight(
name="lora_kernel_a",
shape=self._kernel.shape[:-1] + (rank,),
initializer=initializers.get(a_initializer),
regularizer=self.kernel_regularizer,
)
self.lora_kernel_b = self.add_weight(
name="lora_kernel_b",
shape=(rank, self.filters),
initializer=initializers.get(b_initializer),
regularizer=self.kernel_regularizer,
)
self._kernel.trainable = False
self._tracker.lock()
self.lora_enabled = True
self.lora_rank = rank
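A minimal sketch of calling `enable_lora()` on an already-built layer, consistent with the checks above (the layer must be built, have no `kernel_constraint`, and not already have LoRA enabled); sizes are arbitrary.

```python
import keras

# Hedged sketch: enabling LoRA after construction, as `enable_lora` above allows.
layer = keras.layers.Conv2D(filters=8, kernel_size=3)
layer.build((None, 28, 28, 1))   # enable_lora() raises on an unbuilt layer

layer.enable_lora(2)             # rank=2; a second call would raise
assert layer.lora_enabled and layer.lora_rank == 2

# `_kernel` is now non-trainable; the trainable set is bias + the LoRA factors.
print([v.name for v in layer.trainable_variables])
```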

def save_own_variables(self, store):
# Do nothing if the layer isn't yet built
if not self.built:
return
store["0"] = self.kernel
if self.use_bias:
store["1"] = self.bias

def load_own_variables(self, store):
if not self.lora_enabled:
self._check_load_own_variables(store)
# Do nothing if the layer isn't yet built
if not self.built:
return
self._kernel.assign(store["0"])
if self.use_bias:
self.bias.assign(store["1"])
if self.lora_enabled:
self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
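The net effect of the two overrides above: saving always writes the merged kernel (base plus LoRA delta) into a single slot, and loading pushes those merged weights back into `_kernel` while zeroing the LoRA factors. A rough round-trip sketch, using a plain dict as the store for illustration:

```python
import keras

# Hedged sketch of the save/load round trip implemented above.
layer = keras.layers.Conv2D(4, 3, lora_rank=2)
layer.build((None, 8, 8, 3))

store = {}
layer.save_own_variables(store)   # store["0"] is the merged kernel, store["1"] the bias

layer.load_own_variables(store)   # merged weights go into `_kernel`;
                                  # lora_kernel_a / lora_kernel_b are reset to zeros
```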

def get_config(self):
config = super().get_config()
config.update(
@@ -282,4 +358,40 @@ def get_config(self):
"bias_constraint": constraints.serialize(self.bias_constraint),
}
)
if self.lora_rank:
config["lora_rank"] = self.lora_rank
return config
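Because `lora_rank` is written into the config, a layer rebuilt from its config re-enables LoRA as soon as it is built. A small sketch (again assuming `Conv2D` forwards `lora_rank` to the base class):

```python
import keras

# Hedged sketch of the config round trip enabled by the `lora_rank` entry above.
layer = keras.layers.Conv2D(4, 3, lora_rank=2)
config = layer.get_config()
assert config["lora_rank"] == 2

clone = keras.layers.Conv2D.from_config(config)
clone.build((None, 8, 8, 3))   # build() ends by calling enable_lora(2)
assert clone.lora_enabled
```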

def _check_load_own_variables(self, store):
all_vars = self._trainable_variables + self._non_trainable_variables
if len(store.keys()) != len(all_vars):
if len(all_vars) == 0 and not self.built:
raise ValueError(
f"Layer '{self.name}' was never built "
"and thus it doesn't have any variables. "
f"However the weights file lists {len(store.keys())} "
"variables for this layer.\n"
"In most cases, this error indicates that either:\n\n"
"1. The layer is owned by a parent layer that "
"implements a `build()` method, but calling the "
"parent's `build()` method did NOT create the state of "
f"the child layer '{self.name}'. A `build()` method "
"must create ALL state for the layer, including "
"the state of any children layers.\n\n"
"2. You need to implement "
"the `def build_from_config(self, config)` method "
f"on layer '{self.name}', to specify how to rebuild "
"it during loading. "
"In this case, you might also want to implement the "
"method that generates the build config at saving time, "
"`def get_build_config(self)`. "
"The method `build_from_config()` is meant "
"to create the state "
"of the layer (i.e. its variables) upon deserialization.",
)
raise ValueError(
f"Layer '{self.name}' expected {len(all_vars)} variables, "
"but received "
f"{len(store.keys())} variables during loading. "
f"Expected: {[v.name for v in all_vars]}"
)