[ADD] Efficient implementation of frobenius_norm for all structures. #73

Merged: 15 commits, Jun 23, 2024
10 changes: 10 additions & 0 deletions singd/structures/base.py
@@ -9,6 +9,7 @@
import torch
import torch.distributed as dist
from torch import Tensor, zeros
from torch.linalg import matrix_norm

from singd.structures.utils import diag_add_, supported_eye

@@ -343,6 +344,15 @@ def infinity_vector_norm(self) -> Tensor:
# NOTE `.max` can only be called on tensors with non-zero shape
return max(t.abs().max() for _, t in self.named_tensors() if t.numel() > 0)

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
self._warn_naive_implementation("frobenius_norm")
return matrix_norm(self.to_dense())

###############################################################################
# Special initialization operations #
###############################################################################
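Not part of the diff, but for context: the specialized implementations below avoid `to_dense()` by using the fact that the Frobenius norm of a matrix equals the Euclidean norm of its flattened entries. A minimal sketch of that identity:

import torch
from torch.linalg import matrix_norm, vector_norm

X = torch.rand(4, 4)
# `matrix_norm` defaults to the Frobenius norm, which equals the 2-norm of the
# flattened entries; the structure-specific overrides below build on this.
assert torch.allclose(matrix_norm(X), vector_norm(X.flatten()))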
11 changes: 11 additions & 0 deletions singd/structures/blockdiagonal.py
@@ -7,6 +7,7 @@
import torch
from einops import rearrange
from torch import Tensor, arange, cat, einsum, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
from singd.structures.utils import lowest_precision, supported_eye
@@ -313,6 +314,16 @@ def diag_add_(self, value: float) -> BlockDiagonalMatrixTemplate:

return self

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
return vector_norm(
cat([t.flatten() for _, t in self.named_tensors() if t.numel() > 0])
)

###############################################################################
# Special initialization operations #
###############################################################################
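For intuition (illustration only, block sizes made up): every nonzero entry of a block-diagonal matrix lives in exactly one block, so concatenating the blocks' entries and taking the vector 2-norm reproduces the dense Frobenius norm.

import torch
from torch.linalg import matrix_norm, vector_norm

blocks = [torch.rand(3, 3), torch.rand(3, 3), torch.rand(2, 2)]
dense = torch.block_diag(*blocks)
# All nonzero entries sit inside the blocks, so no densification is needed.
assert torch.allclose(
    matrix_norm(dense),
    vector_norm(torch.cat([b.flatten() for b in blocks])),
)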
9 changes: 9 additions & 0 deletions singd/structures/diagonal.py
@@ -6,6 +6,7 @@

import torch
from torch import Tensor, einsum, ones, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix

@@ -179,6 +180,14 @@ def diag_add_(self, value: float) -> DiagonalMatrix:
self._mat_diag.add_(value)
return self

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
return vector_norm(self._mat_diag)

###############################################################################
# Special initialization operations #
###############################################################################
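For the diagonal structure the identity is immediate: `diag(d)` has no off-diagonal entries, so its Frobenius norm is just the Euclidean norm of the stored vector. A quick check (illustration only):

import torch
from torch.linalg import matrix_norm, vector_norm

d = torch.rand(5)
# The only nonzero entries of diag(d) are the entries of d itself.
assert torch.allclose(matrix_norm(torch.diag(d)), vector_norm(d))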
11 changes: 11 additions & 0 deletions singd/structures/hierarchical.py
@@ -6,6 +6,7 @@

import torch
from torch import Tensor, arange, cat, einsum, ones, zeros
from torch.linalg import vector_norm

from singd.structures.base import StructuredMatrix
from singd.structures.utils import diag_add_, lowest_precision, supported_eye
@@ -353,6 +354,16 @@ def diag_add_(self, value: float) -> HierarchicalMatrixTemplate:
diag_add_(self.E, value)
return self

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
return vector_norm(
cat([t.flatten() for _, t in self.named_tensors() if t.numel() > 0])
)

###############################################################################
# Special initialization operations #
###############################################################################
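The hierarchical case uses the same argument as the block-diagonal one above: the tensors yielded by `named_tensors` parametrize disjoint parts of the dense matrix, so the 2-norm of their concatenated entries equals the Frobenius norm without forming the dense matrix.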
37 changes: 33 additions & 4 deletions singd/structures/recursive.py
@@ -39,18 +39,47 @@ def register_substructure(self, substructure: StructuredMatrix, name: str) -> None:
setattr(self, name, substructure)
self._substructure_names.append(name)

def named_tensors(
self, include_substructures: bool = True
) -> Iterator[Tuple[str, Tensor]]:
"""Yield all tensors that represent the matrix and their names.

Args:
include_substructures: If `True`, also include the tensors of the
substructures. If `False`, exclude them. Default is `True`.

Yields:
A tuple of the tensor's name and the tensor itself.
"""
for name in self._tensor_names:
yield name, getattr(self, name)
if include_substructures:
for subname, substructure in self.named_substructures():
for name, tensor in substructure.named_tensors():
yield f"{name}.{subname}", tensor

def named_substructures(self) -> Iterator[Tuple[str, StructuredMatrix]]:
"""Yield all substructures and their names.

Yields:
A tuple of the substructure's name and the substructure itself.
"""
for name in self._substructure_names:
yield name, getattr(self, name)

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
fro_squared = sum(
(t**2).sum() for _, t in self.named_tensors(include_substructures=False)
)
fro_squared_sub = sum(
s.frobenius_norm() ** 2 for _, s in self.named_substructures()
)
return (fro_squared + fro_squared_sub).sqrt()


class RecursiveTopRightMatrixTemplate(RecursiveStructuredMatrix):
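The recursive implementation exploits that squared Frobenius norms of disjoint blocks add up, and that substructures can report their own norms. A minimal sketch with hypothetical dense blocks in a `[[A, B], [0, C]]` layout (block shapes are made up for illustration):

import torch
from torch.linalg import matrix_norm

# In the structured case, A and C would be StructuredMatrix instances whose
# norms are obtained recursively via their own `frobenius_norm`.
A, B, C = torch.rand(3, 3), torch.rand(3, 2), torch.rand(2, 2)
dense = torch.cat(
    [torch.cat([A, B], dim=1), torch.cat([torch.zeros(2, 3), C], dim=1)], dim=0
)
# Squared Frobenius norms of the disjoint blocks sum to the squared total norm.
assert torch.allclose(
    matrix_norm(dense),
    (matrix_norm(A) ** 2 + matrix_norm(B) ** 2 + matrix_norm(C) ** 2).sqrt(),
)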
17 changes: 17 additions & 0 deletions singd/structures/triltoeplitz.py
@@ -6,6 +6,7 @@

import torch
from torch import Tensor, arange, cat, zeros
from torch.linalg import vector_norm
from torch.nn.functional import conv1d, pad

from singd.structures.base import StructuredMatrix
@@ -191,6 +192,22 @@ def diag_add_(self, value: float) -> TrilToeplitzMatrix:
self._lower_diags[0].add_(value)
return self

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
(dim,) = self._lower_diags.shape
multiplicity = arange(
dim,
0,
step=-1,
dtype=self._lower_diags.dtype,
device=self._lower_diags.device,
)
return vector_norm(self._lower_diags * multiplicity.sqrt())

###############################################################################
# Special initialization operations #
###############################################################################
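Toeplitz structures are the one case where stored values are shared across entries: the value parametrizing the k-th sub-diagonal appears `dim - k` times in the dense matrix, which is why each value is weighted by the square root of its multiplicity. A small consistency check (illustration only):

import torch
from torch.linalg import matrix_norm, vector_norm

dim = 4
diags = torch.rand(dim)  # entry k parametrizes the k-th sub-diagonal
# Dense lower-triangular Toeplitz matrix, built diagonal by diagonal.
dense = sum(torch.diag(diags[k].repeat(dim - k), diagonal=-k) for k in range(dim))
# The k-th value occurs (dim - k) times, hence the sqrt of the multiplicity.
multiplicity = torch.arange(dim, 0, -1, dtype=diags.dtype)
assert torch.allclose(matrix_norm(dense), vector_norm(diags * multiplicity.sqrt()))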
17 changes: 17 additions & 0 deletions singd/structures/triutoeplitz.py
@@ -6,6 +6,7 @@

import torch
from torch import Tensor, arange, cat, triu_indices, zeros
from torch.linalg import vector_norm
from torch.nn.functional import conv1d, pad

from singd.structures.base import StructuredMatrix
@@ -189,6 +190,22 @@ def diag_add_(self, value: float) -> TriuToeplitzMatrix:
self._upper_diags[0].add_(value)
return self

def frobenius_norm(self) -> Tensor:
"""Compute the Frobenius norm of the represented matrix.

Returns:
The Frobenius norm of the represented matrix.
"""
(dim,) = self._upper_diags.shape
multiplicity = arange(
dim,
0,
step=-1,
dtype=self._upper_diags.dtype,
device=self._upper_diags.device,
)
return vector_norm(self._upper_diags * multiplicity.sqrt())

###############################################################################
# Special initialization operations #
###############################################################################
18 changes: 17 additions & 1 deletion test/structures/utils.py
@@ -11,7 +11,7 @@
from matplotlib import pyplot as plt
from pytest import mark
from torch import Tensor, device, manual_seed, rand, zeros
from torch.linalg import matrix_norm, vector_norm

from singd.structures.base import StructuredMatrix
from singd.structures.utils import is_half_precision, supported_eye
@@ -585,6 +585,22 @@ def test_infinity_vector_norm(self, dev: device, dtype: torch.dtype):
structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
report_nonclose(truth, structured.infinity_vector_norm())

@mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS)
def test_frobenius_norm(self, dev: device, dtype: torch.dtype):
"""Test Frobenius norm of a structured matrix.

Args:
dev: The device on which to run the test.
dtype: The data type of the matrices.
"""
for dim in self.DIMS:
manual_seed(0)
sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype))
truth = matrix_norm(self.project(sym_mat))
structured = self.STRUCTURED_MATRIX_CLS.from_dense(sym_mat)
report_nonclose(truth, structured.frobenius_norm())

@mark.expensive
def test_visual(self):
"""Create pictures and animations of the structure.
Expand Down
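A usage sketch (the diagonal structure is just an example; `from_dense` and `frobenius_norm` are the methods exercised by the test above):

import torch
from singd.structures.diagonal import DiagonalMatrix

dense = torch.diag(torch.rand(8))
structured = DiagonalMatrix.from_dense(dense)
# Same value as torch.linalg.matrix_norm(dense), computed without densifying.
fro = structured.frobenius_norm()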