Skip to content

Commit

Permalink
[ADD] Implement shape of a structured matrix (#81)
Browse files Browse the repository at this point in the history
* [ADD] Implement shape of a structured matrix

* [REF] Use `shape` to infer dimensions

* [DOC] Update changelog
  • Loading branch information
f-dangel authored Nov 4, 2024
1 parent 2d4a405 commit 1ac69a3
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 13 deletions.
12 changes: 11 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ Versioning](https://semver.org/spec/v2.0.0.html).

### Fixed

## [0.0.5] - 2024-11-04

This release adds a minor improvement to the structured matrix sub-module.

### Added

- `.shape` property of a `StructuredMatrix`
([PR](https://github.com/f-dangel/singd/pull/81))

## [0.0.4] - 2024-07-03

This release adds a new interface function to `SINGD`'s structured matrix
Expand Down Expand Up @@ -95,7 +104,8 @@ No bug fixes

Initial release

[unreleased]: https://github.com/f-dangel/singd/compare/v0.0.4...HEAD
[unreleased]: https://github.com/f-dangel/singd/compare/v0.0.5...HEAD
[0.0.5]: https://github.com/f-dangel/singd/compare/v0.0.4...v0.0.5
[0.0.4]: https://github.com/f-dangel/singd/compare/v0.0.3...v0.0.4
[0.0.3]: https://github.com/f-dangel/singd/compare/v0.0.2...v0.0.3
[0.0.2]: https://github.com/f-dangel/singd/compare/v0.0.1...v0.0.2
Expand Down
2 changes: 1 addition & 1 deletion singd/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _update_preconditioner(self, module: Module):
# in `m_K, m_C`
scale = 0.5 * (1.0 - alpha1 if normalize_lr_cov else 1.0)

dim_K, dim_C = self.preconditioner_dims(module)
(dim_K,), (dim_C,) = set(K.shape), set(C.shape)
(dtype_K, dtype_C), dev = self._get_preconditioner_dtypes_and_device(module)

# step for m_K
Expand Down
12 changes: 11 additions & 1 deletion singd/structures/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
import torch.distributed as dist
from torch import Tensor, zeros
from torch import Size, Tensor, zeros
from torch.linalg import matrix_norm

from singd.structures.utils import diag_add_, supported_eye
Expand Down Expand Up @@ -82,6 +82,16 @@ def named_tensors(self) -> Iterator[Tuple[str, Tensor]]:
for name in self._tensor_names:
yield name, getattr(self, name)

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    This generic fallback first reports the call through
    ``_warn_naive_implementation`` and then reads the shape off the tensor
    returned by ``to_dense``.

    Returns:
        The shape of the matrix.
    """
    self._warn_naive_implementation("shape")
    return self.to_dense().shape

def __matmul__(
self, other: Union[StructuredMatrix, Tensor]
) -> Union[StructuredMatrix, Tensor]:
Expand Down
18 changes: 17 additions & 1 deletion singd/structures/blockdiagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from einops import rearrange
from torch import Tensor, arange, cat, einsum, zeros
from torch import Size, Tensor, arange, cat, einsum, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
Expand Down Expand Up @@ -109,6 +109,22 @@ def __init__(self, blocks: Tensor, last: Tensor) -> None:
self._last: Tensor
self.register_tensor(last, "_last")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    Each side is ``num_blocks * BLOCK_DIM`` plus the corresponding side of
    ``_last``, where ``num_blocks`` is the leading dimension of ``_blocks``.

    Returns:
        The shape of the matrix.
    """
    num_blocks, _, _ = self._blocks.shape
    last_rows, last_cols = self._last.shape
    # Side length contributed by the stacked blocks; shared by rows and columns.
    blocks_dim = num_blocks * self.BLOCK_DIM
    return Size((blocks_dim + last_rows, blocks_dim + last_cols))

@classmethod
def from_dense(cls, mat: Tensor) -> BlockDiagonalMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

import torch
from torch import Tensor, einsum, ones, zeros
from torch import Size, Tensor, einsum, ones, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
Expand Down Expand Up @@ -47,6 +47,15 @@ def __init__(self, mat_diag: Tensor) -> None:
self._mat_diag: Tensor
self.register_tensor(mat_diag, "_mat_diag")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    The shape of the stored diagonal is concatenated with itself, so a
    diagonal of shape ``(n,)`` yields ``(n, n)``.

    Returns:
        The shape of the matrix.
    """
    diag_shape = self._mat_diag.shape
    return diag_shape + diag_shape

def __matmul__(
self, other: Union[DiagonalMatrix, Tensor]
) -> Union[DiagonalMatrix, Tensor]:
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Tuple, Union

import torch
from torch import Tensor, arange, cat, einsum, ones, zeros
from torch import Size, Tensor, arange, cat, einsum, ones, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
Expand Down Expand Up @@ -142,6 +142,15 @@ def __init__(self, A: Tensor, B: Tensor, C: Tensor, D: Tensor, E: Tensor):
self.E: Tensor
self.register_tensor(E, "E")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    Both sides equal ``self.dim``.

    Returns:
        The shape of the matrix.
    """
    return Size((self.dim, self.dim))

@classmethod
def from_dense(cls, sym_mat: Tensor) -> HierarchicalMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down
30 changes: 25 additions & 5 deletions singd/structures/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Iterator, List, Tuple, Type, Union

from torch import Tensor, block_diag
from torch import Size, Tensor, block_diag

from singd.structures.base import StructuredMatrix

Expand Down Expand Up @@ -164,8 +164,7 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix):
f"{self.CLS_C}, respectively. Got {type(A)} and {type(C)}."
)

# TODO Add a `dim` property to make this cheaper
dim_A, dim_C = A.to_dense().shape[0], C.to_dense().shape[0]
(dim_A,), (dim_C,) = set(A.shape), set(C.shape)
if B.shape != (dim_A, dim_C):
raise ValueError(f"Shape of `B` ({B.shape}) should be ({(dim_A, dim_C)}).")

Expand All @@ -184,6 +183,17 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix):
self.C: StructuredMatrix
self.register_substructure(C, "C")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    Rows and columns are the sums of the corresponding sides of the
    sub-structures ``A`` and ``C``.

    Returns:
        The shape of the matrix.
    """
    A_rows, A_cols = self.A.shape
    C_rows, C_cols = self.C.shape
    return Size((A_rows + C_rows, A_cols + C_cols))

@classmethod
def from_dense(cls, sym_mat: Tensor) -> RecursiveTopRightMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down Expand Up @@ -302,8 +312,7 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix):
f"{self.CLS_C}, respectively. Got {type(A)} and {type(C)}."
)

# TODO Add a `dim` property to make this cheaper
dim_A, dim_C = A.to_dense().shape[0], C.to_dense().shape[0]
# Unpacking a one-element set only succeeds when all entries of a shape are
# equal, so this also rejects sub-structures that are not square.
(dim_A,), (dim_C,) = set(A.shape), set(C.shape)
# `B` must be `(dim_C, dim_A)`; report exactly that shape in the error.
if B.shape != (dim_C, dim_A):
    raise ValueError(f"Shape of `B` ({B.shape}) should be ({(dim_C, dim_A)}).")

Expand All @@ -322,6 +331,17 @@ def __init__(self, A: StructuredMatrix, B: Tensor, C: StructuredMatrix):
self.C: StructuredMatrix
self.register_substructure(C, "C")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    Rows and columns are the sums of the corresponding sides of the
    sub-structures ``A`` and ``C``.

    Returns:
        The shape of the matrix.
    """
    A_rows, A_cols = self.A.shape
    C_rows, C_cols = self.C.shape
    return Size((A_rows + C_rows, A_cols + C_cols))

@classmethod
def from_dense(cls, sym_mat: Tensor) -> RecursiveBottomLeftMatrixTemplate:
"""Construct from a PyTorch tensor.
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/triltoeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

import torch
from torch import Tensor, arange, cat, zeros
from torch import Size, Tensor, arange, cat, zeros
from torch.linalg import vector_norm
from torch.nn.functional import conv1d, pad

Expand Down Expand Up @@ -55,6 +55,15 @@ def __init__(self, lower_diags: Tensor) -> None:
self._lower_diags: Tensor
self.register_tensor(lower_diags, "_lower_diags")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    The shape of the stored lower diagonals is concatenated with itself, so
    ``_lower_diags`` of shape ``(n,)`` yields ``(n, n)``.

    Returns:
        The shape of the matrix.
    """
    diags_shape = self._lower_diags.shape
    return diags_shape + diags_shape

@classmethod
def from_dense(cls, mat: Tensor) -> TrilToeplitzMatrix:
"""Construct from a PyTorch tensor.
Expand Down
11 changes: 10 additions & 1 deletion singd/structures/triutoeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Union

import torch
from torch import Tensor, arange, cat, triu_indices, zeros
from torch import Size, Tensor, arange, cat, triu_indices, zeros
from torch.linalg import vector_norm
from torch.nn.functional import conv1d, pad

Expand Down Expand Up @@ -55,6 +55,15 @@ def __init__(self, upper_diags: Tensor) -> None:
self._upper_diags: Tensor
self.register_tensor(upper_diags, "_upper_diags")

@property
def shape(self) -> Size:
    """Return the structured matrix's shape.

    The shape of the stored upper diagonals is concatenated with itself, so
    ``_upper_diags`` of shape ``(n,)`` yields ``(n, n)``.

    Returns:
        The shape of the matrix.
    """
    diags_shape = self._upper_diags.shape
    return diags_shape + diags_shape

@classmethod
def from_dense(cls, mat: Tensor) -> TriuToeplitzMatrix:
"""Construct from a PyTorch tensor.
Expand Down
15 changes: 15 additions & 0 deletions test/structures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,21 @@ def test_frobenius_norm(self, dev: device, dtype: torch.dtype):
structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
report_nonclose(truth, structured.frobenius_norm())

@mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
def test_shape(self, dev: device, dtype: torch.dtype):
    """Test shape of a structured matrix.

    For every dimension in ``self.DIMS``, a symmetrized random matrix is
    converted with ``from_dense`` and the structured matrix's ``shape`` is
    compared against the dense matrix's shape.

    Args:
        dev: The device on which to run the test.
        dtype: The data type of the matrices.
    """
    for dim in self.DIMS:
        # Re-seed per dimension so each case draws reproducible values.
        manual_seed(0)
        sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype))
        structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
        assert sym_mat.shape == structured.shape

@mark.expensive
def test_visual(self):
"""Create pictures and animations of the structure.
Expand Down

0 comments on commit 1ac69a3

Please sign in to comment.