make LoRA generic
catwell committed Feb 6, 2024
1 parent 593a0f5 commit 60d5d0c
Showing 3 changed files with 31 additions and 32 deletions.
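In short, the commit parameterizes Lora over the concrete type of its down/up sub-modules instead of typing them as bare fl.WeightedModule. A minimal sketch of the pattern, using self-contained stand-ins rather than the actual refiners classes:

from abc import ABC
from typing import Generic, TypeVar, cast


# Stand-ins for the refiners fl.WeightedModule hierarchy (assumption: simplified).
class WeightedModule: ...
class Linear(WeightedModule): ...
class Conv2d(WeightedModule): ...


T = TypeVar("T", bound=WeightedModule)


class Lora(Generic[T], ABC):
    """A LoRA layer whose down and up projections are both of type T."""

    def __init__(self, down: T, up: T) -> None:
        # Stored in a plain container, so the static type degrades to WeightedModule.
        self._modules: list[WeightedModule] = [down, up]

    @property
    def down(self) -> T:
        # cast() narrows the static type back to T for callers.
        return cast(T, self._modules[0])

    @property
    def up(self) -> T:
        return cast(T, self._modules[1])


class LinearLora(Lora[Linear]): ...
class Conv2dLora(Lora[Conv2d]): ...


lora = LinearLora(Linear(), Linear())
down_layer: Linear = lora.down  # type-checks without an isinstance() guard

The cast is needed because the sub-modules live in an untyped container at runtime; the generic parameter only informs the type checker, which is what lets the isinstance assertions in the tests below go away.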
41 changes: 21 additions & 20 deletions src/refiners/fluxion/adapters/lora.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Generic, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -7,8 +8,10 @@
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
 
+T = TypeVar("T", bound=fl.WeightedModule)
+
 
-class Lora(fl.Chain, ABC):
+class Lora(Generic[T], fl.Chain, ABC):
     """Low-Rank Adaptation (LoRA) layer.
 
     This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
@@ -55,9 +58,7 @@ def __init__(
         zeros_(tensor=self.up.weight)
 
     @abstractmethod
-    def lora_layers(
-        self, device: Device | str | None = None, dtype: DType | None = None
-    ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
+    def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
         """Create the down and up layers of the LoRA.
 
         Args:
@@ -67,18 +68,18 @@ def lora_layers(
         ...
 
     @property
-    def down(self) -> fl.WeightedModule:
+    def down(self) -> T:
         """The down layer."""
         down_layer = self[0]
         assert isinstance(down_layer, fl.WeightedModule)
-        return down_layer
+        return cast(T, down_layer)
 
     @property
-    def up(self) -> fl.WeightedModule:
+    def up(self) -> T:
         """The up layer."""
         up_layer = self[1]
         assert isinstance(up_layer, fl.WeightedModule)
-        return up_layer
+        return cast(T, up_layer)
 
     @property
     def rank(self) -> int:
@@ -102,7 +103,7 @@ def from_weights(
         /,
         down: Tensor,
         up: Tensor,
-    ) -> "Lora":
+    ) -> "Lora[Any]":
         match (up.ndim, down.ndim):
             case (2, 2):
                 return LinearLora.from_weights(name, up=up, down=down)
@@ -112,14 +113,14 @@ def from_weights(
                 raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
 
     @classmethod
-    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
+    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
         """
         Create a dictionary of LoRA layers from a state dict.
 
         Expects the state dict to be a succession of down and up weights.
         """
         state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
-        loras: dict[str, Lora] = {}
+        loras: dict[str, Lora[Any]] = {}
         for down_key, down_tensor, up_tensor in zip(
             list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
         ):
@@ -168,7 +169,7 @@ def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
         self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
 
 
-class LinearLora(Lora):
+class LinearLora(Lora[fl.Linear]):
     """Low-Rank Adaptation (LoRA) layer for linear layers.
 
     This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
@@ -254,7 +255,7 @@ def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
         return False
 
 
-class Conv2dLora(Lora):
+class Conv2dLora(Lora[fl.Conv2d]):
    """Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
 
     This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
@@ -374,7 +375,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     This adapter simply sums the target layer with the given LoRA layers.
     """
 
-    def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
+    def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
         """Initialize the adapter.
 
         Args:
@@ -387,24 +388,24 @@ def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora)]
+        return [lora.name for lora in self.layers(Lora[Any])]
 
     @property
-    def loras(self) -> dict[str, Lora]:
+    def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora)}
+        return {lora.name: lora for lora in self.layers(Lora[Any])}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora)}
+        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
         for name, value in values.items():
             self.loras[name].scale = value
 
-    def add_lora(self, lora: Lora, /) -> None:
+    def add_lora(self, lora: Lora[Any], /) -> None:
         """Add a LoRA layer to the adapter.
 
         Raises:
@@ -416,7 +417,7 @@ def add_lora(self, lora: Lora, /) -> None:
         assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
         self.append(lora)
 
-    def remove_lora(self, name: str, /) -> Lora | None:
+    def remove_lora(self, name: str, /) -> Lora[Any] | None:
         """Remove a LoRA layer from the adapter.
 
         Note:
15 changes: 8 additions & 7 deletions src/refiners/foundationals/latent_diffusion/lora.py

@@ -1,3 +1,4 @@
+from typing import Any
 from warnings import warn
 
 from torch import Tensor
@@ -106,7 +107,7 @@ def add_multiple_loras(
         for name, lora_tensors in tensors.items():
             self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
 
-    def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the text encoder.
 
         Args:
@@ -116,7 +117,7 @@ def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
         SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
 
-    def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the U-Net.
 
         Args:
@@ -147,7 +148,7 @@ def remove_all(self) -> None:
         for lora_adapter in self.lora_adapters:
             lora_adapter.eject()
 
-    def get_loras_by_name(self, name: str, /) -> list[Lora]:
+    def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
         """Get the LoRA layers with the given name.
 
         Args:
@@ -190,9 +191,9 @@ def update_scales(self, scales: dict[str, float], /) -> None:
             lora.scale = scale
 
     @property
-    def loras(self) -> list[Lora]:
+    def loras(self) -> list[Lora[Any]]:
         """List of all the LoRA layers managed by the SDLoraManager."""
-        return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
+        return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
 
     @property
     def names(self) -> list[str]:
@@ -239,12 +240,12 @@ def sort_keys(key: str, /) -> tuple[str, int]:
 
     @staticmethod
     def auto_attach(
-        loras: dict[str, Lora],
+        loras: dict[str, Lora[Any]],
         target: fl.Chain,
         /,
         exclude: list[str] | None = None,
     ) -> None:
-        failed_loras: dict[str, Lora] = {}
+        failed_loras: dict[str, Lora[Any]] = {}
         for key, lora in loras.items():
             if attach := lora.auto_attach(target, exclude=exclude):
                 adapter, parent = attach
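Because a single checkpoint can mix linear and convolutional LoRAs, the manager-level annotations become dict[str, Lora[Any]] rather than a concrete type parameter. A hedged sketch of how the two APIs touched here fit together, using only the signatures visible in this diff (the "my_lora" name and the state_dict contents are placeholders):

from typing import Any

from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora
from refiners.foundationals.latent_diffusion.lora import SDLoraManager


def attach_loras(state_dict: dict[str, Tensor], target: fl.Chain) -> dict[str, "Lora[Any]"]:
    # from_dict pairs successive down/up weights and builds LinearLora or
    # Conv2dLora instances depending on the weight shapes, hence Lora[Any].
    loras: dict[str, Lora[Any]] = Lora.from_dict("my_lora", state_dict=state_dict)
    # auto_attach looks for a matching parent layer in `target` for each LoRA
    # (the adapter/parent pair in the loop above) and attaches it there,
    # skipping any layer whose name appears in `exclude`.
    SDLoraManager.auto_attach(loras, target, exclude=None)
    return loras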
7 changes: 2 additions & 5 deletions tests/adapters/test_lora.py

@@ -11,11 +11,11 @@ def lora() -> LinearLora:
 
 
 @pytest.fixture
-def conv_lora() -> Lora:
+def conv_lora() -> Conv2dLora:
     return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
 
 
-def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
+def test_properties(lora: LinearLora, conv_lora: Conv2dLora) -> None:
     assert lora.name == "test"
     assert lora.rank == lora.down.out_features == lora.up.in_features == 16
     assert lora.scale == 1.0
@@ -27,7 +27,6 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
     assert conv_lora.scale == 1.0
     assert conv_lora.in_channels == conv_lora.down.in_channels == 16
     assert conv_lora.out_channels == conv_lora.up.out_channels == 8
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
     # padding is set so the spatial dimensions are preserved
     assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@@ -40,12 +39,10 @@ def test_scale_setter(lora: LinearLora) -> None:
 
 
 def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
-    assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
     new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
     x = torch.randn(1, 320)
     assert torch.allclose(lora(x), new_lora(x))
 
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
     x = torch.randn(1, 16, 64, 64)
     assert torch.allclose(conv_lora(x), new_conv_lora(x))
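The test deletions above illustrate the payoff: conv_lora.down is now statically known to be an fl.Conv2d and lora.down an fl.Linear, so the isinstance assertions add nothing. A small usage sketch mirroring the fixtures; the Conv2dLora arguments are taken verbatim from the conv_lora fixture, while the LinearLora arguments are an assumption consistent with rank == 16 and the (1, 320) test input, since the lora fixture body is not shown in this diff:

import torch

from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora

linear_lora = LinearLora("test", in_features=320, out_features=128, rank=16)  # assumed arguments
conv_lora = Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)

# Typed accessors: no isinstance() needed before touching layer attributes.
assert linear_lora.down.in_features == 320
assert conv_lora.up.out_channels == 8

# A Lora is an fl.Chain, so it is callable on tensors; padding is chosen so
# the conv variant preserves spatial dimensions.
y = linear_lora(torch.randn(1, 320))       # shape (1, 128) under the assumed out_features
z = conv_lora(torch.randn(1, 16, 64, 64))  # shape (1, 8, 64, 64)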
