From 01cb3e1e6cd98cc83bd849aac838340310159e7b Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Sun, 14 Apr 2024 21:22:09 +0800
Subject: [PATCH 1/5] Add LoRA to `BaseConv`

---
 keras/layers/convolutional/base_conv.py | 71 +++++++++++++++++++++----
 1 file changed, 62 insertions(+), 9 deletions(-)

diff --git a/keras/layers/convolutional/base_conv.py b/keras/layers/convolutional/base_conv.py
index eebbd53fc1c..eca1c7bcda9 100644
--- a/keras/layers/convolutional/base_conv.py
+++ b/keras/layers/convolutional/base_conv.py
@@ -1,5 +1,7 @@
 """Keras base class for convolution layers."""

+import math
+
 from keras import activations
 from keras import constraints
 from keras import initializers
@@ -92,16 +94,10 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
-        trainable=True,
-        name=None,
+        lora_rank=None,
         **kwargs,
     ):
-        super().__init__(
-            trainable=trainable,
-            name=name,
-            activity_regularizer=activity_regularizer,
-            **kwargs,
-        )
+        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.rank = rank
         self.filters = filters
         self.groups = groups
@@ -120,6 +116,8 @@ def __init__(
         self.bias_regularizer = regularizers.get(bias_regularizer)
         self.kernel_constraint = constraints.get(kernel_constraint)
         self.bias_constraint = constraints.get(bias_constraint)
+        self.lora_rank = lora_rank
+        self.lora_enabled = False
         self.input_spec = InputSpec(min_ndim=self.rank + 2)
         self.data_format = self.data_format
@@ -187,7 +185,7 @@ def build(self, input_shape):
         # shape, and make sure the output shape has all positive dimensions.
         self.compute_output_shape(input_shape)

-        self.kernel = self.add_weight(
+        self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
             initializer=self.kernel_initializer,
@@ -210,6 +208,19 @@ def build(self, input_shape):
             self.bias = None
         self.built = True

+    @property
+    def kernel(self):
+        if not self.built:
+            raise AttributeError(
+                "You must build the layer before accessing `kernel`."
+            )
+        if self.lora_enabled:
+            return self._kernel + ops.reshape(
+                ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                self._kernel.shape,
+            )
+        return self._kernel
+
     def convolution_op(self, inputs, kernel):
         return ops.conv(
             inputs,
@@ -248,6 +259,46 @@ def compute_output_shape(self, input_shape):
             dilation_rate=self.dilation_rate,
         )

+    def enable_lora(
+        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+    ):
+        if self.kernel_constraint:
+            raise ValueError(
+                "Lora is incompatible with kernel constraints. "
+                "In order to enable lora on this layer, remove the "
+                "`kernel_constraint` argument."
+            )
+        if not self.built:
+            raise ValueError(
+                "Cannot enable lora on a layer that isn't yet built."
+            )
+        if self.lora_enabled:
+            raise ValueError(
+                "lora is already enabled. "
+                "This can only be done once per layer."
+            )
+        if self.groups != 1:
+            raise ValueError("lora is incompatible with `groups` != 1.")
+
+        self._tracker.unlock()
+        input_channel = self._kernel.shape[-2]
+        self.lora_kernel_a = self.add_weight(
+            name="lora_kernel_a",
+            shape=(input_channel * math.prod(self.kernel_size), rank),
+            initializer=initializers.get(a_initializer),
+            regularizer=self.kernel_regularizer,
+        )
+        self.lora_kernel_b = self.add_weight(
+            name="lora_kernel_b",
+            shape=(rank, self.filters),
+            initializer=initializers.get(b_initializer),
+            regularizer=self.kernel_regularizer,
+        )
+        self._kernel.trainable = False
+        self._tracker.lock()
+        self.lora_enabled = True
+        self.lora_rank = rank
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -282,4 +333,6 @@ def get_config(self):
                 "bias_constraint": constraints.serialize(self.bias_constraint),
             }
         )
+        if self.lora_rank:
+            config["lora_rank"] = self.lora_rank
         return config

From 6a94c95b77ad8427ab271009e4997b0872a425da Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 15 Apr 2024 22:05:57 +0800
Subject: [PATCH 2/5] Add tests

---
 keras/layers/convolutional/base_conv.py |  79 +++++++--
 keras/layers/convolutional/conv_test.py | 209 ++++++++++++++++++++++++
 2 files changed, 278 insertions(+), 10 deletions(-)

diff --git a/keras/layers/convolutional/base_conv.py b/keras/layers/convolutional/base_conv.py
index eca1c7bcda9..689083d9400 100644
--- a/keras/layers/convolutional/base_conv.py
+++ b/keras/layers/convolutional/base_conv.py
@@ -1,7 +1,5 @@
 """Keras base class for convolution layers."""

-import math
-
 from keras import activations
 from keras import constraints
 from keras import initializers
@@ -73,6 +71,15 @@ class BaseConv(Layer):
         are not safe to use when doing asynchronous distributed training.
     bias_constraint: Optional projection function to be applied to the
         bias after being updated by an `Optimizer`.
+    lora_rank: Optional integer. If set, the layer's forward pass
+        will implement LoRA (Low-Rank Adaptation)
+        with the provided rank. LoRA sets the layer's kernel
+        to non-trainable and replaces it with a delta over the
+        original kernel, obtained via multiplying two lower-rank
+        trainable matrices. This can be useful to reduce the
+        computation cost of fine-tuning large convolution layers.
+        You can also enable LoRA on an existing layer by calling
+        `layer.enable_lora(rank)`.
     """

     def __init__(
@@ -207,6 +214,8 @@ def build(self, input_shape):
         else:
             self.bias = None
         self.built = True
+        if self.lora_rank:
+            self.enable_lora(self.lora_rank)

     @property
     def kernel(self):
@@ -215,9 +224,8 @@ def kernel(self):
                 "You must build the layer before accessing `kernel`."
             )
         if self.lora_enabled:
-            return self._kernel + ops.reshape(
-                ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
-                self._kernel.shape,
+            return self._kernel + ops.matmul(
+                self.lora_kernel_a, self.lora_kernel_b
             )
         return self._kernel

@@ -277,14 +285,10 @@ def enable_lora(
                 "lora is already enabled. "
                 "This can only be done once per layer."
             )
-        if self.groups != 1:
-            raise ValueError("lora is incompatible with `groups` != 1.")
-
         self._tracker.unlock()
-        input_channel = self._kernel.shape[-2]
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(input_channel * math.prod(self.kernel_size), rank),
+            shape=self._kernel.shape[:-1] + (rank,),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
@@ -299,6 +303,27 @@ def enable_lora(
         self.lora_enabled = True
         self.lora_rank = rank

+    def save_own_variables(self, store):
+        # Do nothing if the layer isn't yet built
+        if not self.built:
+            return
+        store["0"] = self.kernel
+        if self.use_bias:
+            store["1"] = self.bias
+
+    def load_own_variables(self, store):
+        if not self.lora_enabled:
+            self._check_load_own_variables(store)
+        # Do nothing if the layer isn't yet built
+        if not self.built:
+            return
+        self._kernel.assign(store["0"])
+        if self.use_bias:
+            self.bias.assign(store["1"])
+        if self.lora_enabled:
+            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
+            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -336,3 +361,37 @@ def get_config(self):
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
         return config
+
+    def _check_load_own_variables(self, store):
+        all_vars = self._trainable_variables + self._non_trainable_variables
+        if len(store.keys()) != len(all_vars):
+            if len(all_vars) == 0 and not self.built:
+                raise ValueError(
+                    f"Layer '{self.name}' was never built "
+                    "and thus it doesn't have any variables. "
+                    f"However the weights file lists {len(store.keys())} "
+                    "variables for this layer.\n"
+                    "In most cases, this error indicates that either:\n\n"
+                    "1. The layer is owned by a parent layer that "
+                    "implements a `build()` method, but calling the "
+                    "parent's `build()` method did NOT create the state of "
+                    f"the child layer '{self.name}'. A `build()` method "
+                    "must create ALL state for the layer, including "
+                    "the state of any children layers.\n\n"
+                    "2. You need to implement "
+                    "the `def build_from_config(self, config)` method "
+                    f"on layer '{self.name}', to specify how to rebuild "
+                    "it during loading. "
+                    "In this case, you might also want to implement the "
+                    "method that generates the build config at saving time, "
+                    "`def get_build_config(self)`. "
+                    "The method `build_from_config()` is meant "
+                    "to create the state "
+                    "of the layer (i.e. its variables) upon deserialization.",
+                )
+            raise ValueError(
+                f"Layer '{self.name}' expected {len(all_vars)} variables, "
+                "but received "
+                f"{len(store.keys())} variables during loading. "
+                f"Expected: {[v.name for v in all_vars]}"
+            )
diff --git a/keras/layers/convolutional/conv_test.py b/keras/layers/convolutional/conv_test.py
index 62ec5de8b0e..5a2a80b6111 100644
--- a/keras/layers/convolutional/conv_test.py
+++ b/keras/layers/convolutional/conv_test.py
@@ -1,10 +1,15 @@
+import os
+
 import numpy as np
 import pytest
 from absl.testing import parameterized
 from numpy.lib.stride_tricks import as_strided

+from keras import backend
 from keras import constraints
 from keras import layers
+from keras import models
+from keras import saving
 from keras import testing


@@ -538,6 +543,210 @@ def test_bad_init_args(self):
         ):
             layers.Conv2D(filters=5, kernel_size=(2, 2), groups=2)

+    @parameterized.named_parameters(
+        {
+            "testcase_name": "conv1d_kernel_size3_strides1",
+            "conv_cls": layers.Conv1D,
+            "filters": 6,
+            "kernel_size": 3,
+            "strides": 1,
+            "padding": "valid",
+            "dilation_rate": 1,
+            "groups": 1,
+            "input_shape": (None, 5, 4),
+            "output_shape": (None, 3, 6),
+        },
+        {
+            "testcase_name": "conv1d_kernel_size1_strides2",
+            "conv_cls": layers.Conv1D,
+            "filters": 6,
+            "kernel_size": 2,
+            "strides": 2,
+            "padding": "valid",
+            "dilation_rate": 1,
+            "groups": 2,
+            "input_shape": (None, 5, 4),
+            "output_shape": (None, 2, 6),
+        },
+        {
+            "testcase_name": "conv2d_kernel_size3_strides1",
+            "conv_cls": layers.Conv2D,
+            "filters": 6,
+            "kernel_size": 3,
+            "strides": 1,
+            "padding": "valid",
+            "dilation_rate": 1,
+            "groups": 1,
+            "input_shape": (None, 5, 5, 4),
+            "output_shape": (None, 3, 3, 6),
+        },
+        {
+            "testcase_name": "conv2d_kernel_size1_strides2",
+            "conv_cls": layers.Conv2D,
+            "filters": 6,
+            "kernel_size": 2,
+            "strides": 2,
+            "padding": "valid",
+            "dilation_rate": 1,
+            "groups": 2,
+            "input_shape": (None, 5, 5, 4),
+            "output_shape": (None, 2, 2, 6),
+        },
+        {
+            "testcase_name": "conv3d_kernel_size3_strides1",
+            "conv_cls": layers.Conv3D,
+            "filters": 6,
+            "kernel_size": 3,
+            "strides": 1,
+            "padding": "valid",
+            "dilation_rate": 1,
+            "groups": 1,
+            "input_shape": (None, 5, 5, 5, 4),
+            "output_shape": (None, 3, 3, 3, 6),
+        },
+        {
+            "testcase_name": "conv3d_kernel_size1_strides2",
+            "conv_cls": layers.Conv3D,
+            "filters": 6,
+            "kernel_size": 2,
+            "strides": 2,
+            "padding": "valid",
+            "dilation_rate": 1,
+            "groups": 2,
+            "input_shape": (None, 5, 5, 5, 4),
+            "output_shape": (None, 2, 2, 2, 6),
+        },
+    )
+    @pytest.mark.requires_trainable_backend
+    def test_enable_lora(
+        self,
+        conv_cls,
+        filters,
+        kernel_size,
+        strides,
+        padding,
+        dilation_rate,
+        groups,
+        input_shape,
+        output_shape,
+    ):
+        if conv_cls not in (layers.Conv1D, layers.Conv2D, layers.Conv3D):
+            raise TypeError(f"Unexpected `conv_cls`: {conv_cls}")
+        layer = conv_cls(
+            filters=filters,
+            kernel_size=kernel_size,
+            strides=strides,
+            padding=padding,
+            dilation_rate=dilation_rate,
+            groups=groups,
+        )
+        layer.build(input_shape)
+        layer.enable_lora(2)
+        self.assertLen(layer.trainable_weights, 3)
+        self.assertLen(layer.non_trainable_weights, 1)
+        if backend.backend() == "torch":
+            self.assertLen(layer.torch_params, 4)
+        # Try eager call
+        x = np.random.random((64,) + input_shape[1:])
+        y = np.random.random((64,) + output_shape[1:])
+        _ = layer(x[:2])
+
+        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
+        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()
+
+        # Try calling fit()
+        model = models.Sequential([layer])
+        model.compile(optimizer="sgd", loss="mse")
+        model.fit(x, y)
+
+        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()
+        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()
+        diff_a = np.max(
+            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)
+        )
+        diff_b = np.max(
+            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)
+        )
+        self.assertGreater(diff_a, 0.0)
+        self.assertGreater(diff_b, 0.0)
+
+        # Try saving and reloading the model
+        temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras")
+        model.save(temp_filepath)
+
+        new_model = saving.load_model(temp_filepath)
+        self.assertTrue(new_model.layers[0].lora_enabled)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+        # Try saving and reloading the model's weights only
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "lora_model.weights.h5"
+        )
+        model.save_weights(temp_filepath)
+
+        # Load the file into a fresh, non-lora model
+        new_model = models.Sequential(
+            [
+                conv_cls(
+                    filters=filters,
+                    kernel_size=kernel_size,
+                    strides=strides,
+                    padding=padding,
+                    dilation_rate=dilation_rate,
+                    groups=groups,
+                )
+            ]
+        )
+        new_model.build(input_shape)
+        new_model.load_weights(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+        # Try loading a normal checkpoint into a lora model
+        new_model.save_weights(temp_filepath)
+        model.load_weights(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+    @pytest.mark.requires_trainable_backend
+    def test_lora_weight_name(self):
+
+        class MyModel(models.Model):
+            def __init__(self):
+                super().__init__(name="mymodel")
+                self.conv2d = layers.Conv2D(4, 3, name="conv2d")
+
+            def build(self, input_shape):
+                self.conv2d.build(input_shape)
+
+            def call(self, x):
+                return self.conv2d(x)
+
+        model = MyModel()
+        model.build((None, 5, 5, 4))
+        model.conv2d.enable_lora(2)
+        self.assertEqual(
+            model.conv2d.lora_kernel_a.path, "mymodel/conv2d/lora_kernel_a"
+        )
+
+    @pytest.mark.requires_trainable_backend
+    def test_lora_rank_argument(self):
+        self.run_layer_test(
+            layers.Conv2D,
+            init_kwargs={
+                "filters": 5,
+                "kernel_size": 3,
+                "activation": "sigmoid",
+                "kernel_regularizer": "l2",
+                "lora_rank": 2,
+            },
+            input_shape=(2, 5, 5, 4),
+            expected_output_shape=(2, 3, 3, 5),
+            expected_num_trainable_weights=3,
+            expected_num_non_trainable_weights=1,
+            expected_num_seed_generators=0,
+            expected_num_losses=2,  # we have 2 regularizers.
+            supports_masking=False,
+        )
+

 class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
     @parameterized.parameters(

From 00ba3491489500d6286089117071cecb9c2a99fd Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 15 Apr 2024 22:07:25 +0800
Subject: [PATCH 3/5] Fix typo

---
 keras/layers/convolutional/conv_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras/layers/convolutional/conv_test.py b/keras/layers/convolutional/conv_test.py
index 5a2a80b6111..72222398945 100644
--- a/keras/layers/convolutional/conv_test.py
+++ b/keras/layers/convolutional/conv_test.py
@@ -557,7 +557,7 @@ def test_bad_init_args(self):
             "output_shape": (None, 3, 6),
         },
         {
-            "testcase_name": "conv1d_kernel_size1_strides2",
+            "testcase_name": "conv1d_kernel_size2_strides2",
             "conv_cls": layers.Conv1D,
             "filters": 6,
             "kernel_size": 2,
@@ -581,7 +581,7 @@ def test_bad_init_args(self):
             "output_shape": (None, 3, 3, 6),
         },
         {
-            "testcase_name": "conv2d_kernel_size1_strides2",
+            "testcase_name": "conv2d_kernel_size2_strides2",
             "conv_cls": layers.Conv2D,
             "filters": 6,
             "kernel_size": 2,
@@ -605,7 +605,7 @@ def test_bad_init_args(self):
             "output_shape": (None, 3, 3, 3, 6),
         },
         {
-            "testcase_name": "conv3d_kernel_size1_strides2",
+            "testcase_name": "conv3d_kernel_size2_strides2",
            "conv_cls": layers.Conv3D,
             "filters": 6,
             "kernel_size": 2,

From a6c6192420fb02a10f6ddf6136578c084ce2eb62 Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 15 Apr 2024 22:35:44 +0800
Subject: [PATCH 4/5] Fix tests

---
 keras/layers/convolutional/conv_test.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/keras/layers/convolutional/conv_test.py b/keras/layers/convolutional/conv_test.py
index 72222398945..b4a4b32783a 100644
--- a/keras/layers/convolutional/conv_test.py
+++ b/keras/layers/convolutional/conv_test.py
@@ -551,6 +551,7 @@ def test_bad_init_args(self):
             "kernel_size": 3,
             "strides": 1,
             "padding": "valid",
+            "data_format": "channels_last",
             "dilation_rate": 1,
             "groups": 1,
             "input_shape": (None, 5, 4),
@@ -563,6 +564,7 @@ def test_bad_init_args(self):
             "kernel_size": 2,
             "strides": 2,
             "padding": "valid",
+            "data_format": "channels_last",
             "dilation_rate": 1,
             "groups": 2,
             "input_shape": (None, 5, 4),
@@ -575,6 +577,7 @@ def test_bad_init_args(self):
             "kernel_size": 3,
             "strides": 1,
             "padding": "valid",
+            "data_format": "channels_last",
             "dilation_rate": 1,
             "groups": 1,
             "input_shape": (None, 5, 5, 4),
@@ -587,6 +590,7 @@ def test_bad_init_args(self):
             "kernel_size": 2,
             "strides": 2,
             "padding": "valid",
+            "data_format": "channels_last",
             "dilation_rate": 1,
             "groups": 2,
             "input_shape": (None, 5, 5, 4),
@@ -599,6 +603,7 @@ def test_bad_init_args(self):
             "kernel_size": 3,
             "strides": 1,
             "padding": "valid",
+            "data_format": "channels_last",
             "dilation_rate": 1,
             "groups": 1,
             "input_shape": (None, 5, 5, 5, 4),
@@ -611,6 +616,7 @@ def test_bad_init_args(self):
             "kernel_size": 2,
             "strides": 2,
             "padding": "valid",
+            "data_format": "channels_last",
             "dilation_rate": 1,
             "groups": 2,
             "input_shape": (None, 5, 5, 5, 4),
@@ -625,6 +631,7 @@ def test_enable_lora(
         kernel_size,
         strides,
         padding,
+        data_format,
         dilation_rate,
         groups,
         input_shape,
@@ -637,6 +644,7 @@ def test_enable_lora(
             kernel_size=kernel_size,
             strides=strides,
             padding=padding,
+            data_format=data_format,
             dilation_rate=dilation_rate,
             groups=groups,
         )
@@ -692,6 +700,7 @@ def test_enable_lora(
                     kernel_size=kernel_size,
                     strides=strides,
                     padding=padding,
+                    data_format=data_format,
                     dilation_rate=dilation_rate,
                     groups=groups,
                 )

From f0136d182f51164918eb21fc8cd44dc76ac70799 Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 15 Apr 2024 23:02:29 +0800
Subject: [PATCH 5/5] Fix tests

---
 keras/layers/convolutional/conv_test.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/keras/layers/convolutional/conv_test.py b/keras/layers/convolutional/conv_test.py
index b4a4b32783a..281e35a960e 100644
--- a/keras/layers/convolutional/conv_test.py
+++ b/keras/layers/convolutional/conv_test.py
@@ -744,6 +744,7 @@ def test_lora_rank_argument(self):
                 "filters": 5,
                 "kernel_size": 3,
                 "activation": "sigmoid",
+                "data_format": "channels_last",
                 "kernel_regularizer": "l2",
                 "lora_rank": 2,
             },
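
Notes for readers of this series (not part of the patches). A minimal usage sketch of the API the patches add, assuming a Keras 3 build that includes this change; layer sizes and data below are illustrative:

    import numpy as np
    from keras import layers, models

    # Freeze the conv kernel and adapt it with two rank-2 factors.
    layer = layers.Conv2D(filters=8, kernel_size=3, padding="same")
    layer.build((None, 32, 32, 3))
    layer.enable_lora(2)  # or pass lora_rank=2 to the constructor

    # The base kernel is now non-trainable; only lora_kernel_a,
    # lora_kernel_b, and the bias receive gradients.
    print(len(layer.trainable_weights))      # 3
    print(len(layer.non_trainable_weights))  # 1

    model = models.Sequential([layer])
    model.compile(optimizer="sgd", loss="mse")
    x = np.random.random((4, 32, 32, 3)).astype("float32")
    y = np.random.random((4, 32, 32, 8)).astype("float32")
    model.fit(x, y, epochs=1, verbose=0)

Reading `layer.kernel` afterwards returns the effective weights: the frozen base kernel plus the low-rank delta.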
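On the shape change between PATCH 1 and PATCH 2: PATCH 1 built `lora_kernel_a` as a flattened `(input_channels * prod(kernel_size), rank)` matrix and reshaped the product back, while PATCH 2 keeps the kernel's leading axes on `lora_kernel_a`, so the matmul broadcasts straight to the kernel shape and both the reshape and the `math` import can go. A self-contained check of that broadcast (shapes chosen arbitrarily):

    import numpy as np

    kernel_shape = (3, 3, 4, 8)  # kernel_size + (input_channels, filters)
    rank = 2
    a = np.zeros(kernel_shape[:-1] + (rank,))  # (3, 3, 4, 2)
    b = np.zeros((rank, kernel_shape[-1]))     # (2, 8)

    # Stacked matmul: (3, 3, 4, 2) @ (2, 8) -> (3, 3, 4, 8), which is
    # already the kernel shape, so the delta adds onto the kernel directly.
    delta = a @ b
    assert delta.shape == kernel_shape

    # The point of the low rank: 88 trainable entries in the two factors
    # versus 288 in the full kernel.
    print(a.size + b.size, int(np.prod(kernel_shape)))  # 88 288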
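On the checkpoint semantics added in PATCH 2: `save_own_variables` stores the `kernel` property, i.e. the base kernel with the LoRA delta already merged in, so a LoRA checkpoint is indistinguishable from a plain one; `load_own_variables` assigns the merged weights to the frozen `_kernel` and zeros both factors, which leaves the forward pass unchanged. A hedged sketch of the round trip the new test exercises (file name illustrative):

    import numpy as np
    from keras import layers, models

    x = np.random.random((2, 8, 8, 3)).astype("float32")

    lora_model = models.Sequential([layers.Conv2D(4, 3, lora_rank=2)])
    lora_model.build((None, 8, 8, 3))
    lora_model.save_weights("conv_lora.weights.h5")

    # Only the merged kernel and the bias were written, so a plain
    # (non-LoRA) model of the same architecture accepts the file as-is.
    plain_model = models.Sequential([layers.Conv2D(4, 3)])
    plain_model.build((None, 8, 8, 3))
    plain_model.load_weights("conv_lora.weights.h5")

    np.testing.assert_allclose(
        lora_model.predict(x), plain_model.predict(x), rtol=1e-5
    )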