diff --git a/keras/layers/convolutional/base_conv.py b/keras/layers/convolutional/base_conv.py
index eebbd53fc1c..689083d9400 100644
--- a/keras/layers/convolutional/base_conv.py
+++ b/keras/layers/convolutional/base_conv.py
@@ -71,6 +71,15 @@ class BaseConv(Layer):
             are not safe to use when doing asynchronous distributed training.
         bias_constraint: Optional projection function to be applied to the
             bias after being updated by an `Optimizer`.
+        lora_rank: Optional integer. If set, the layer's forward pass
+            will implement LoRA (Low-Rank Adaptation)
+            with the provided rank. LoRA sets the layer's kernel
+            to non-trainable and replaces it with a delta over the
+            original kernel, obtained via multiplying two lower-rank
+            trainable matrices. This can be useful to reduce the
+            computation cost of fine-tuning large convolution layers.
+            You can also enable LoRA on an existing layer by calling
+            `layer.enable_lora(rank)`.
     """
 
     def __init__(
@@ -92,16 +101,10 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
-        trainable=True,
-        name=None,
+        lora_rank=None,
         **kwargs,
     ):
-        super().__init__(
-            trainable=trainable,
-            name=name,
-            activity_regularizer=activity_regularizer,
-            **kwargs,
-        )
+        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.rank = rank
         self.filters = filters
         self.groups = groups
@@ -120,6 +123,8 @@ def __init__(
         self.bias_regularizer = regularizers.get(bias_regularizer)
         self.kernel_constraint = constraints.get(kernel_constraint)
         self.bias_constraint = constraints.get(bias_constraint)
+        self.lora_rank = lora_rank
+        self.lora_enabled = False
         self.input_spec = InputSpec(min_ndim=self.rank + 2)
         self.data_format = self.data_format
 
@@ -187,7 +192,7 @@ def build(self, input_shape):
         # shape, and make sure the output shape has all positive dimensions.
         self.compute_output_shape(input_shape)
 
-        self.kernel = self.add_weight(
+        self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
             initializer=self.kernel_initializer,
@@ -209,6 +214,20 @@ def build(self, input_shape):
         else:
             self.bias = None
         self.built = True
+        if self.lora_rank:
+            self.enable_lora(self.lora_rank)
+
+    @property
+    def kernel(self):
+        if not self.built:
+            raise AttributeError(
+                "You must build the layer before accessing `kernel`."
+            )
+        if self.lora_enabled:
+            return self._kernel + ops.matmul(
+                self.lora_kernel_a, self.lora_kernel_b
+            )
+        return self._kernel
 
     def convolution_op(self, inputs, kernel):
         return ops.conv(
@@ -248,6 +267,63 @@ def compute_output_shape(self, input_shape):
             dilation_rate=self.dilation_rate,
         )
 
+    def enable_lora(
+        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+    ):
+        if self.kernel_constraint:
+            raise ValueError(
+                "Lora is incompatible with kernel constraints. "
+                "In order to enable lora on this layer, remove the "
+                "`kernel_constraint` argument."
+            )
+        if not self.built:
+            raise ValueError(
+                "Cannot enable lora on a layer that isn't yet built."
+            )
+        if self.lora_enabled:
+            raise ValueError(
+                "lora is already enabled. "
+                "This can only be done once per layer."
+            )
+        self._tracker.unlock()
+        self.lora_kernel_a = self.add_weight(
+            name="lora_kernel_a",
+            shape=self._kernel.shape[:-1] + (rank,),
+            initializer=initializers.get(a_initializer),
+            regularizer=self.kernel_regularizer,
+        )
+        self.lora_kernel_b = self.add_weight(
+            name="lora_kernel_b",
+            shape=(rank, self.filters),
+            initializer=initializers.get(b_initializer),
+            regularizer=self.kernel_regularizer,
+        )
+        self._kernel.trainable = False
+        self._tracker.lock()
+        self.lora_enabled = True
+        self.lora_rank = rank
+
+    def save_own_variables(self, store):
+        # Do nothing if the layer isn't yet built
+        if not self.built:
+            return
+        store["0"] = self.kernel
+        if self.use_bias:
+            store["1"] = self.bias
+
+    def load_own_variables(self, store):
+        if not self.lora_enabled:
+            self._check_load_own_variables(store)
+        # Do nothing if the layer isn't yet built
+        if not self.built:
+            return
+        self._kernel.assign(store["0"])
+        if self.use_bias:
+            self.bias.assign(store["1"])
+        if self.lora_enabled:
+            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
+            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -282,4 +358,40 @@ def get_config(self):
                 "bias_constraint": constraints.serialize(self.bias_constraint),
             }
         )
+        if self.lora_rank:
+            config["lora_rank"] = self.lora_rank
         return config
+
+    def _check_load_own_variables(self, store):
+        all_vars = self._trainable_variables + self._non_trainable_variables
+        if len(store.keys()) != len(all_vars):
+            if len(all_vars) == 0 and not self.built:
+                raise ValueError(
+                    f"Layer '{self.name}' was never built "
+                    "and thus it doesn't have any variables. "
+                    f"However the weights file lists {len(store.keys())} "
+                    "variables for this layer.\n"
+                    "In most cases, this error indicates that either:\n\n"
+                    "1. The layer is owned by a parent layer that "
+                    "implements a `build()` method, but calling the "
+                    "parent's `build()` method did NOT create the state of "
+                    f"the child layer '{self.name}'. A `build()` method "
+                    "must create ALL state for the layer, including "
+                    "the state of any children layers.\n\n"
+                    "2. You need to implement "
+                    "the `def build_from_config(self, config)` method "
+                    f"on layer '{self.name}', to specify how to rebuild "
+                    "it during loading. "
+                    "In this case, you might also want to implement the "
+                    "method that generates the build config at saving time, "
+                    "`def get_build_config(self)`. "
+                    "The method `build_from_config()` is meant "
+                    "to create the state "
+                    "of the layer (i.e. its variables) upon deserialization.",
+                )
+            raise ValueError(
+                f"Layer '{self.name}' expected {len(all_vars)} variables, "
+                "but received "
+                f"{len(store.keys())} variables during loading. "
" + f"Expected: {[v.name for v in all_vars]}" + ) diff --git a/keras/layers/convolutional/conv_test.py b/keras/layers/convolutional/conv_test.py index 62ec5de8b0e..281e35a960e 100644 --- a/keras/layers/convolutional/conv_test.py +++ b/keras/layers/convolutional/conv_test.py @@ -1,10 +1,15 @@ +import os + import numpy as np import pytest from absl.testing import parameterized from numpy.lib.stride_tricks import as_strided +from keras import backend from keras import constraints from keras import layers +from keras import models +from keras import saving from keras import testing @@ -538,6 +543,220 @@ def test_bad_init_args(self): ): layers.Conv2D(filters=5, kernel_size=(2, 2), groups=2) + @parameterized.named_parameters( + { + "testcase_name": "conv1d_kernel_size3_strides1", + "conv_cls": layers.Conv1D, + "filters": 6, + "kernel_size": 3, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (None, 5, 4), + "output_shape": (None, 3, 6), + }, + { + "testcase_name": "conv1d_kernel_size2_strides2", + "conv_cls": layers.Conv1D, + "filters": 6, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (None, 5, 4), + "output_shape": (None, 2, 6), + }, + { + "testcase_name": "conv2d_kernel_size3_strides1", + "conv_cls": layers.Conv2D, + "filters": 6, + "kernel_size": 3, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (None, 5, 5, 4), + "output_shape": (None, 3, 3, 6), + }, + { + "testcase_name": "conv2d_kernel_size2_strides2", + "conv_cls": layers.Conv2D, + "filters": 6, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (None, 5, 5, 4), + "output_shape": (None, 2, 2, 6), + }, + { + "testcase_name": "conv3d_kernel_size3_strides1", + "conv_cls": layers.Conv3D, + "filters": 6, + "kernel_size": 3, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (None, 5, 5, 5, 4), + "output_shape": (None, 3, 3, 3, 6), + }, + { + "testcase_name": "conv3d_kernel_size2_strides2", + "conv_cls": layers.Conv3D, + "filters": 6, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (None, 5, 5, 5, 4), + "output_shape": (None, 2, 2, 2, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_enable_lora( + self, + conv_cls, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + input_shape, + output_shape, + ): + if conv_cls not in (layers.Conv1D, layers.Conv2D, layers.Conv3D): + raise TypeError + layer = conv_cls( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + layer.build(input_shape) + layer.enable_lora(2) + self.assertLen(layer.trainable_weights, 3) + self.assertLen(layer.non_trainable_weights, 1) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 4) + # Try eager call + x = np.random.random((64,) + input_shape[1:]) + y = np.random.random((64,) + output_shape[1:]) + _ = layer(x[:2]) + + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + + # Try calling fit() + model = 
+        model = models.Sequential([layer])
+        model.compile(optimizer="sgd", loss="mse")
+        model.fit(x, y)
+
+        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()
+        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()
+        diff_a = np.max(
+            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)
+        )
+        diff_b = np.max(
+            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)
+        )
+        self.assertGreater(diff_a, 0.0)
+        self.assertGreater(diff_b, 0.0)
+
+        # Try saving and reloading the model
+        temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras")
+        model.save(temp_filepath)
+
+        new_model = saving.load_model(temp_filepath)
+        self.assertTrue(new_model.layers[0].lora_enabled)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+        # Try saving and reloading the model's weights only
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "lora_model.weights.h5"
+        )
+        model.save_weights(temp_filepath)
+
+        # Load the file into a fresh, non-lora model
+        new_model = models.Sequential(
+            [
+                conv_cls(
+                    filters=filters,
+                    kernel_size=kernel_size,
+                    strides=strides,
+                    padding=padding,
+                    data_format=data_format,
+                    dilation_rate=dilation_rate,
+                    groups=groups,
+                )
+            ]
+        )
+        new_model.build(input_shape)
+        new_model.load_weights(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+        # Try loading a normal checkpoint into a lora model
+        new_model.save_weights(temp_filepath)
+        model.load_weights(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+    @pytest.mark.requires_trainable_backend
+    def test_lora_weight_name(self):
+
+        class MyModel(models.Model):
+            def __init__(self):
+                super().__init__(name="mymodel")
+                self.conv2d = layers.Conv2D(4, 3, name="conv2d")
+
+            def build(self, input_shape):
+                self.conv2d.build(input_shape)
+
+            def call(self, x):
+                return self.conv2d(x)
+
+        model = MyModel()
+        model.build((None, 5, 5, 4))
+        model.conv2d.enable_lora(2)
+        self.assertEqual(
+            model.conv2d.lora_kernel_a.path, "mymodel/conv2d/lora_kernel_a"
+        )
+
+    @pytest.mark.requires_trainable_backend
+    def test_lora_rank_argument(self):
+        self.run_layer_test(
+            layers.Conv2D,
+            init_kwargs={
+                "filters": 5,
+                "kernel_size": 3,
+                "activation": "sigmoid",
+                "data_format": "channels_last",
+                "kernel_regularizer": "l2",
+                "lora_rank": 2,
+            },
+            input_shape=(2, 5, 5, 4),
+            expected_output_shape=(2, 3, 3, 5),
+            expected_num_trainable_weights=3,
+            expected_num_non_trainable_weights=1,
+            expected_num_seed_generators=0,
+            expected_num_losses=2,  # we have 2 regularizers.
+            supports_masking=False,
+        )
+
 
 class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
     @parameterized.parameters(
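Usage sketch (supplementary, not part of the diff): a minimal example of the workflow enabled by the `lora_rank` / `enable_lora()` API added above, assuming the changes in this diff are applied; the shapes, rank, and hyperparameters below are illustrative only.

import numpy as np
from keras import layers, models

# Build a small Conv2D layer, then freeze its kernel and add low-rank adapters.
layer = layers.Conv2D(8, 3, data_format="channels_last")
layer.build((None, 16, 16, 4))
# After this call, `kernel` is non-trainable and the trainable weights are
# `lora_kernel_a` and `lora_kernel_b`; the forward pass uses
# kernel + lora_kernel_a @ lora_kernel_b as the effective kernel.
layer.enable_lora(2)

# Fine-tune only the low-rank adapters.
model = models.Sequential([layer])
model.compile(optimizer="sgd", loss="mse")
x = np.random.random((8, 16, 16, 4))
y = np.random.random((8, 14, 14, 8))
model.fit(x, y, epochs=1)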