diff --git a/changelog.md b/changelog.md index 10cc4c3..7d1231a 100644 --- a/changelog.md +++ b/changelog.md @@ -16,6 +16,15 @@ Versioning](https://semver.org/spec/v2.0.0.html). ### Fixed +## [0.0.5] - 2024-11-04 + +This release adds a minor improvement of the structured matrix sub-module. + +### Added + +- `.shape` property of a `StructuredMatrix` + ([PR](https://github.com/f-dangel/singd/pull/81)) + ## [0.0.4] - 2024-07-03 This release adds a new interface function to `SINGD`'s structured matrix @@ -95,7 +104,8 @@ No bug fixes Initial release -[unreleased]: https://github.com/f-dangel/singd/compare/v0.0.4...HEAD +[unreleased]: https://github.com/f-dangel/singd/compare/v0.0.5...HEAD +[0.0.5]: https://github.com/f-dangel/singd/compare/v0.0.4...v0.0.5 [0.0.4]: https://github.com/f-dangel/singd/compare/v0.0.3...v0.0.4 [0.0.3]: https://github.com/f-dangel/singd/compare/v0.0.2...v0.0.3 [0.0.2]: https://github.com/f-dangel/singd/compare/v0.0.1...v0.0.2 diff --git a/singd/optim/optimizer.py b/singd/optim/optimizer.py index 7cad745..fb96bd2 100644 --- a/singd/optim/optimizer.py +++ b/singd/optim/optimizer.py @@ -499,7 +499,7 @@ def _update_preconditioner(self, module: Module): # in `m_K, m_C` scale = 0.5 * (1.0 - alpha1 if normalize_lr_cov else 1.0) - dim_K, dim_C = self.preconditioner_dims(module) + (dim_K,), (dim_C,) = set(K.shape), set(C.shape) (dtype_K, dtype_C), dev = self._get_preconditioner_dtypes_and_device(module) # step for m_K diff --git a/singd/structures/base.py b/singd/structures/base.py index e541c57..600774a 100644 --- a/singd/structures/base.py +++ b/singd/structures/base.py @@ -8,7 +8,7 @@ import torch import torch.distributed as dist -from torch import Tensor, zeros +from torch import Size, Tensor, zeros from torch.linalg import matrix_norm from singd.structures.utils import diag_add_, supported_eye @@ -82,6 +82,16 @@ def named_tensors(self) -> Iterator[Tuple[str, Tensor]]: for name in self._tensor_names: yield name, getattr(self, name) + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + self._warn_naive_implementation("shape") + return self.to_dense().shape + def __matmul__( self, other: Union[StructuredMatrix, Tensor] ) -> Union[StructuredMatrix, Tensor]: diff --git a/singd/structures/blockdiagonal.py b/singd/structures/blockdiagonal.py index d2f22f0..3aa2601 100644 --- a/singd/structures/blockdiagonal.py +++ b/singd/structures/blockdiagonal.py @@ -6,7 +6,7 @@ import torch from einops import rearrange -from torch import Tensor, arange, cat, einsum, zeros +from torch import Size, Tensor, arange, cat, einsum, zeros from torch.linalg import vector_norm from singd.structures.base import StructuredMatrix @@ -109,6 +109,22 @@ def __init__(self, blocks: Tensor, last: Tensor) -> None: self._last: Tensor self.register_tensor(last, "_last") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + num_blocks, _, _ = self._blocks.shape + last_rows, last_cols = self._last.shape + return Size( + ( + num_blocks * self.BLOCK_DIM + last_rows, + num_blocks * self.BLOCK_DIM + last_cols, + ) + ) + @classmethod def from_dense(cls, mat: Tensor) -> BlockDiagonalMatrixTemplate: """Construct from a PyTorch tensor. diff --git a/singd/structures/diagonal.py b/singd/structures/diagonal.py index 838d5ac..1ff2413 100644 --- a/singd/structures/diagonal.py +++ b/singd/structures/diagonal.py @@ -5,7 +5,7 @@ from typing import Union import torch -from torch import Tensor, einsum, ones, zeros +from torch import Size, Tensor, einsum, ones, zeros from torch.linalg import vector_norm from singd.structures.base import StructuredMatrix @@ -47,6 +47,15 @@ def __init__(self, mat_diag: Tensor) -> None: self._mat_diag: Tensor self.register_tensor(mat_diag, "_mat_diag") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + return self._mat_diag.shape + self._mat_diag.shape + def __matmul__( self, other: Union[DiagonalMatrix, Tensor] ) -> Union[DiagonalMatrix, Tensor]: diff --git a/singd/structures/hierarchical.py b/singd/structures/hierarchical.py index 5cb8489..4124360 100644 --- a/singd/structures/hierarchical.py +++ b/singd/structures/hierarchical.py @@ -5,7 +5,7 @@ from typing import Tuple, Union import torch -from torch import Tensor, arange, cat, einsum, ones, zeros +from torch import Size, Tensor, arange, cat, einsum, ones, zeros from torch.linalg import vector_norm from singd.structures.base import StructuredMatrix @@ -142,6 +142,15 @@ def __init__(self, A: Tensor, B: Tensor, C: Tensor, D: Tensor, E: Tensor): self.E: Tensor self.register_tensor(E, "E") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + return Size((self.dim, self.dim)) + @classmethod def from_dense(cls, sym_mat: Tensor) -> HierarchicalMatrixTemplate: """Construct from a PyTorch tensor. diff --git a/singd/structures/recursive.py b/singd/structures/recursive.py index f9d6225..2d70030 100644 --- a/singd/structures/recursive.py +++ b/singd/structures/recursive.py @@ -4,7 +4,7 @@ from typing import Iterator, List, Tuple, Type, Union -from torch import Tensor, block_diag +from torch import Size, Tensor, block_diag from singd.structures.base import StructuredMatrix @@ -164,8 +164,7 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix): f"{self.CLS_C}, respectively. Got {type(A)} and {type(C)}." ) - # TODO Add a `dim` property to make this cheaper - dim_A, dim_C = A.to_dense().shape[0], C.to_dense().shape[0] + (dim_A,), (dim_C,) = set(A.shape), set(C.shape) if B.shape != (dim_A, dim_C): raise ValueError(f"Shape of `B` ({B.shape}) should be ({(dim_A, dim_C)}).") @@ -184,6 +183,17 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix): self.C: StructuredMatrix self.register_substructure(C, "C") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + A_rows, A_cols = self.A.shape + C_rows, C_cols = self.C.shape + return Size((A_rows + C_rows, A_cols + C_cols)) + @classmethod def from_dense(cls, sym_mat: Tensor) -> RecursiveTopRightMatrixTemplate: """Construct from a PyTorch tensor. @@ -302,8 +312,7 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix): f"{self.CLS_C}, respectively. Got {type(A)} and {type(C)}." ) - # TODO Add a `dim` property to make this cheaper - dim_A, dim_C = A.to_dense().shape[0], C.to_dense().shape[0] + (dim_A,), (dim_C,) = set(A.shape), set(C.shape) if B.shape != (dim_C, dim_A): raise ValueError(f"Shape of `B` ({B.shape}) should be ({(dim_A, dim_C)}).") @@ -322,6 +331,17 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix): self.C: StructuredMatrix self.register_substructure(C, "C") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + A_rows, A_cols = self.A.shape + C_rows, C_cols = self.C.shape + return Size((A_rows + C_rows, A_cols + C_cols)) + @classmethod def from_dense(cls, sym_mat: Tensor) -> RecursiveBottomLeftMatrixTemplate: """Construct from a PyTorch tensor. diff --git a/singd/structures/triltoeplitz.py b/singd/structures/triltoeplitz.py index dd29863..84f1f8f 100644 --- a/singd/structures/triltoeplitz.py +++ b/singd/structures/triltoeplitz.py @@ -5,7 +5,7 @@ from typing import Union import torch -from torch import Tensor, arange, cat, zeros +from torch import Size, Tensor, arange, cat, zeros from torch.linalg import vector_norm from torch.nn.functional import conv1d, pad @@ -55,6 +55,15 @@ def __init__(self, lower_diags: Tensor) -> None: self._lower_diags: Tensor self.register_tensor(lower_diags, "_lower_diags") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + return self._lower_diags.shape + self._lower_diags.shape + @classmethod def from_dense(cls, mat: Tensor) -> TrilToeplitzMatrix: """Construct from a PyTorch tensor. diff --git a/singd/structures/triutoeplitz.py b/singd/structures/triutoeplitz.py index 1a6205d..2e5d4a0 100644 --- a/singd/structures/triutoeplitz.py +++ b/singd/structures/triutoeplitz.py @@ -5,7 +5,7 @@ from typing import Union import torch -from torch import Tensor, arange, cat, triu_indices, zeros +from torch import Size, Tensor, arange, cat, triu_indices, zeros from torch.linalg import vector_norm from torch.nn.functional import conv1d, pad @@ -55,6 +55,15 @@ def __init__(self, upper_diags: Tensor) -> None: self._upper_diags: Tensor self.register_tensor(upper_diags, "_upper_diags") + @property + def shape(self) -> Size: + """Return the structured matrix's shape. + + Returns: + The shape of the matrix. + """ + return self._upper_diags.shape + self._upper_diags.shape + @classmethod def from_dense(cls, mat: Tensor) -> TriuToeplitzMatrix: """Construct from a PyTorch tensor. diff --git a/test/structures/utils.py b/test/structures/utils.py index ca4c5a0..39dc37e 100644 --- a/test/structures/utils.py +++ b/test/structures/utils.py @@ -622,6 +622,21 @@ def test_frobenius_norm(self, dev: device, dtype: torch.dtype): structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat) report_nonclose(truth, structured.frobenius_norm()) + @mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS) + @mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) + def test_shape(self, dev: device, dtype: torch.dtype): + """Test shape of a structured matrix. + + Args: + dev: The device on which to run the test. + dtype: The data type of the matrices. + """ + for dim in self.DIMS: + manual_seed(0) + sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype)) + structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat) + assert sym_mat.shape == structured.shape + @mark.expensive def test_visual(self): """Create pictures and animations of the structure.