
Commit

Merge branch 'torch-2.2.0' into frobenius-norm
f-dangel committed May 30, 2024
2 parents 308f7ce + 69a1250 commit 7232e61
Showing 10 changed files with 91 additions and 228 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -32,7 +32,7 @@ setup_requires =
# Dependencies of the project (semicolon/line-separated):
install_requires =
numpy
torch
torch>=2.2.0
einops
einconv
# The usage of test_requires is discouraged, see `Dependency Management` docs
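The only packaging change is raising the torch floor to 2.2.0. A quick, illustrative way to confirm an existing environment meets the new requirement (this check is not part of the diff):

import torch

# torch.__version__ is a string such as "2.2.0+cu121"; compare the first two components
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
assert (major, minor) >= (2, 2), "singd now requires torch>=2.2.0"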
17 changes: 13 additions & 4 deletions singd/optim/optimizer.py
@@ -591,13 +591,16 @@ def _accumulate_H_terms(
kfac_approx = self._get_param_group_entry(module, "kfac_approx")
module_name = self.module_names[module]

# load inputs and gradients to the same precision as the pre-conditioner
(dtype_K, dtype_C), _ = self._get_preconditioner_dtypes_and_device(module)

# 1) PROCESS INPUTS AND GRAD_OUTPUTS
a = inputs[0].data
a = inputs[0].data.to(dtype_K)
# Process into matrix according to kfac_approx
# For convolutions, unfold the input, for modules with bias terms, append a 1
a = process_input(a, module, kfac_approx)

g = grad_output.data
g = grad_output.data.to(dtype_C)
# Process into matrix according to kfac_approx, add scaling from batch average
g = process_grad_output(g, module, loss_average, kfac_approx)

@@ -694,14 +697,20 @@ def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]:
# 2) COMPUTE THE NATURAL GRADIENT IN CONCATENATED MATRIX FORM
module_name = self.module_names[module]

# load the gradient to the pre-conditioner precision while multiplying
dtype_K, dtype_C = self._get_param_group_entry(module, "preconditioner_dtype")

# We need to compute `W @ K @ K^T` where `W` is the weight gradient
# `K` supports `K @ ...` and `K^T @ ...`. Hence, we rewrite into
# `W @ K @ K^T = ( K @ (K^T @ W^T) )^T`.
K = self.Ks[module_name]
nat_grad = (K @ K.rmatmat(grad_mat.T)).T
nat_grad = (K @ K.rmatmat(grad_mat.T.to(dtype_K))).T

C = self.Cs[module_name]
nat_grad = C @ (C.rmatmat(nat_grad))
nat_grad = C @ (C.rmatmat(nat_grad.to(dtype_C)))

# load the pre-conditioned gradient back to the original precision
nat_grad = nat_grad.to(grad_mat.dtype)

# If DDP is used.
if dist.is_initialized():
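Both hunks above follow the same mixed-precision pattern: tensors are cast to the pre-conditioner dtypes (`dtype_K`, `dtype_C`) only for the structured multiplications, and the result is cast back to the gradient's original dtype at the end. A minimal sketch of the natural-gradient step, using plain dense tensors in place of the structured `K` and `C` factors and a transpose in place of their `rmatmat` method (shapes and names here are illustrative, not the optimizer's API):

import torch

def natural_gradient_sketch(grad_mat, K, C, dtype_K, dtype_C):
    # W @ K @ K^T rewritten as (K @ (K^T @ W^T))^T, computed in K's precision
    nat_grad = (K @ (K.T @ grad_mat.T.to(dtype_K))).T
    # apply the C factor in its own precision
    nat_grad = C @ (C.T @ nat_grad.to(dtype_C))
    # cast the pre-conditioned gradient back to the original precision
    return nat_grad.to(grad_mat.dtype)

grad = torch.randn(4, 3)                     # concatenated weight gradient, float32
K = torch.randn(3, 3, dtype=torch.bfloat16)  # input-side factor, kept in low precision
C = torch.randn(4, 4, dtype=torch.bfloat16)  # grad-output-side factor
print(natural_gradient_sketch(grad, K, C, torch.bfloat16, torch.bfloat16).dtype)  # torch.float32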
12 changes: 6 additions & 6 deletions singd/structures/base.py
@@ -11,7 +11,7 @@
from torch import Tensor, zeros
from torch.linalg import matrix_norm

from singd.structures.utils import diag_add_, supported_eye, supported_matmul
from singd.structures.utils import diag_add_, supported_eye


class StructuredMatrix(ABC):
@@ -100,10 +100,10 @@ def __matmul__(

dense = self.to_dense()
if isinstance(other, Tensor):
return supported_matmul(dense, other)
return dense @ other

other_dense = other.to_dense()
return self.from_dense(supported_matmul(dense, other_dense))
return self.from_dense(dense @ other_dense)

@classmethod
@abstractmethod
@@ -206,7 +206,7 @@ def rmatmat(self, mat: Tensor) -> Tensor:
A dense PyTorch tensor resulting from the multiplication.
"""
self._warn_naive_implementation("rmatmat")
return supported_matmul(self.to_dense().T, mat)
return self.to_dense().T @ mat

@classmethod
def _warn_naive_implementation(cls, fn_name: str):
@@ -281,7 +281,7 @@ def from_inner(self, X: Union[Tensor, None] = None) -> StructuredMatrix:
"""
self._warn_naive_implementation("from_inner")
S_dense = self.to_dense().T if X is None else self.rmatmat(X)
return self.from_dense(supported_matmul(S_dense, S_dense.T))
return self.from_dense(S_dense @ S_dense.T)

# NOTE This operation should be removed long-term as implementing IF-KFAC
# with `from_inner` is more efficient. For now, it will exist as it makes
@@ -298,7 +298,7 @@ def from_inner2(self, XXT: Tensor) -> StructuredMatrix:
"""
self._warn_naive_implementation("from_inner2")
dense = self.to_dense()
return self.from_dense(supported_matmul(dense.T, XXT, dense))
return self.from_dense(dense.T @ XXT @ dense)

def average_trace(self) -> Tensor:
"""Compute the average trace of the represented matrix.
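Throughout the structured-matrix classes, the `supported_matmul` helper is replaced by the native `@` operator, presumably because the new `torch>=2.2.0` requirement from `setup.cfg` makes a low-precision fallback wrapper unnecessary. A small illustrative check (assuming a torch build where half-precision matmul is available on the chosen device):

import torch

a = torch.randn(8, 8, dtype=torch.float16)
b = torch.randn(8, 8, dtype=torch.float16)
# the native operator keeps the input dtype; no special-cased wrapper is required
print((a @ b).dtype)  # torch.float16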
28 changes: 10 additions & 18 deletions singd/structures/blockdiagonal.py
@@ -6,15 +6,10 @@

import torch
from einops import rearrange
from torch import Tensor, arange, cat, zeros
from torch import Tensor, arange, cat, einsum, zeros

from singd.structures.base import StructuredMatrix
from singd.structures.utils import (
lowest_precision,
supported_einsum,
supported_eye,
supported_matmul,
)
from singd.structures.utils import lowest_precision, supported_eye


class BlockDiagonalMatrixTemplate(StructuredMatrix):
@@ -194,7 +189,7 @@ def __matmul__(
other_blocks = rearrange(
other_blocks, "(block row) col -> block row col", **dims
)
result_blocks = supported_einsum(
result_blocks = einsum(
"nij,njk->nik",
self._blocks.to(compute_dtype),
other_blocks.to(compute_dtype),

out_dtype = self._last.dtype
compute_dtype = lowest_precision(self._last.dtype, other_last.dtype)
result_last = supported_matmul(
self._last.to(compute_dtype), other_last.to(compute_dtype)
result_last = (
self._last.to(compute_dtype) @ other_last.to(compute_dtype)
).to(out_dtype)

return cat([result_blocks, result_last])

else:
out_blocks = supported_einsum("nij,njk->nik", self._blocks, other._blocks)
out_last = supported_matmul(self._last, other._last)
out_blocks = einsum("nij,njk->nik", self._blocks, other._blocks)
out_last = self._last @ other._last
return self.__class__(out_blocks, out_last)

def __add__(
@@ -286,8 +281,8 @@ def from_inner(self, X: Union[Tensor, None] = None) -> BlockDiagonalMatrixTempla
dims = {"block": num_blocks, "row": self.BLOCK_DIM}
S_blocks = rearrange(S_blocks, "(block row) col -> block row col", **dims)

out_blocks = supported_einsum("nij,nkj->nik", S_blocks, S_blocks)
out_last = supported_matmul(S_last, S_last.T)
out_blocks = einsum("nij,nkj->nik", S_blocks, S_blocks)
out_last = S_last @ S_last.T

return self.__class__(out_blocks, out_last)

@@ -299,10 +294,7 @@ def average_trace(self) -> Tensor:
"""
num_blocks, last_dim = self._blocks.shape[0], self._last.shape[0]
dim = num_blocks * self.BLOCK_DIM + last_dim
return (
supported_einsum("nii->", self._blocks / dim)
+ (self._last.diag() / dim).sum()
)
return einsum("nii->", self._blocks / dim) + (self._last.diag() / dim).sum()

def diag_add_(self, value: float) -> BlockDiagonalMatrixTemplate:
"""In-place add a value to the diagonal of the represented matrix.
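The block-diagonal product is computed block-wise with a batched `einsum` over the stacked diagonal blocks instead of assembling dense matrices. A self-contained sketch of that identity (names here are illustrative; `torch.block_diag` is only used to check the result):

import torch
from torch import einsum

blocks_a = torch.randn(3, 4, 4)  # stacked diagonal blocks of the left factor
blocks_b = torch.randn(3, 4, 4)  # stacked diagonal blocks of the right factor

# per-block product, as in the diff's einsum("nij,njk->nik", ...)
blocks_ab = einsum("nij,njk->nik", blocks_a, blocks_b)

# the assembled result matches the dense product of the assembled factors
dense = torch.block_diag(*blocks_a) @ torch.block_diag(*blocks_b)
print(torch.allclose(torch.block_diag(*blocks_ab), dense, atol=1e-6))  # True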
71 changes: 25 additions & 46 deletions singd/structures/hierarchical.py
@@ -5,16 +5,10 @@
from typing import Tuple, Union

import torch
from torch import Tensor, arange, cat, ones, zeros
from torch import Tensor, arange, cat, einsum, ones, zeros

from singd.structures.base import StructuredMatrix
from singd.structures.utils import (
diag_add_,
lowest_precision,
supported_einsum,
supported_eye,
supported_matmul,
)
from singd.structures.utils import diag_add_, lowest_precision, supported_eye


class HierarchicalMatrixTemplate(StructuredMatrix):
@@ -209,34 +203,25 @@ def __matmul__(
[self.K1, self.diag_dim, self.K2]
)

top = (
supported_matmul(self.A, other_top)
+ supported_matmul(B_C, other_middle)
+ supported_matmul(B_E, other_bottom)
)
middle = supported_einsum("i,ij->ij", self.C, other_middle)
bottom = supported_matmul(self.D, other_middle) + supported_matmul(
self.E, other_bottom
)
top = self.A @ other_top + B_C @ other_middle + B_E @ other_bottom
middle = einsum("i,ij->ij", self.C, other_middle)
bottom = self.D @ other_middle + self.E @ other_bottom

return cat([top, middle, bottom], dim=0)

else:
A_new = supported_matmul(self.A, other.A)
A_new = self.A @ other.A
C_new = self.C * other.C
E_new = supported_matmul(self.E, other.E)
D_new = supported_einsum("ij,j->ij", self.D, other.C) + supported_matmul(
self.E, other.D
)
E_new = self.E @ other.E
D_new = einsum("ij,j->ij", self.D, other.C) + self.E @ other.D

B_C_other, B_E_other = other.B.split([other.diag_dim, other.K2], dim=1)
B_new = cat(
[
supported_matmul(self.A, B_C_other)
+ supported_einsum("ij,j->ij", B_C, other.C)
+ supported_matmul(B_E, other.D),
supported_matmul(self.A, B_E_other)
+ supported_matmul(B_E, other.E),
self.A @ B_C_other
+ einsum("ij,j->ij", B_C, other.C)
+ B_E @ other.D,
self.A @ B_E_other + B_E @ other.E,
],
dim=1,
)
@@ -291,20 +276,18 @@ def rmatmat(self, mat: Tensor) -> Tensor:
# parts of B that share columns with C, E
B_C, B_E = self.B.split([self.diag_dim, self.K2], dim=1)

top = supported_matmul(self.A.T, mat_top)
top = self.A.T @ mat_top

compute_dtype = lowest_precision(self.C.dtype, mat_middle.dtype)
out_dtype = self.C.dtype
middle = (
supported_matmul(B_C.T, mat_top)
+ supported_einsum(
B_C.T @ mat_top
+ einsum(
"i,ij->ij", self.C.to(compute_dtype), mat_middle.to(compute_dtype)
).to(out_dtype)
+ supported_matmul(self.D.T, mat_bottom)
)
bottom = supported_matmul(B_E.T, mat_top) + supported_matmul(
self.E.T, mat_bottom
+ self.D.T @ mat_bottom
)
bottom = B_E.T @ mat_top + self.E.T @ mat_bottom

return cat([top, middle, bottom])

@@ -323,26 +306,22 @@ def from_inner(self, X: Union[Tensor, None] = None) -> HierarchicalMatrixTemplat
`self.T @ X @ X^T @ self`.
"""
if X is None:
A_new = supported_matmul(self.A.T, self.A)
B_new = 2 * supported_matmul(self.A.T, self.B)
A_new = self.A.T @ self.A
B_new = 2 * self.A.T @ self.B

# parts of B that share columns with C, E
B_C, B_E = self.B.split([self.diag_dim, self.K2], dim=1)

C_new = self.C**2 + (B_C**2).sum(0) + (self.D**2).sum(0)
D_new = 2 * (
supported_matmul(B_E.T, B_C) + supported_matmul(self.E.T, self.D)
)
E_new = supported_matmul(self.E.T, self.E) + supported_matmul(B_E.T, B_E)
D_new = 2 * (B_E.T @ B_C + self.E.T @ self.D)
E_new = self.E.T @ self.E + B_E.T @ B_E
else:
S_A, S_C, S_E = self.rmatmat(X).split([self.K1, self.diag_dim, self.K2])
A_new = supported_matmul(S_A, S_A.T)
B_new = 2 * cat(
[supported_matmul(S_A, S_C.T), supported_matmul(S_A, S_E.T)], dim=1
)
A_new = S_A @ S_A.T
B_new = 2 * cat([S_A @ S_C.T, S_A @ S_E.T], dim=1)
C_new = (S_C**2).sum(1)
D_new = 2 * supported_matmul(S_E, S_C.T)
E_new = supported_matmul(S_E, S_E.T)
D_new = 2 * S_E @ S_C.T
E_new = S_E @ S_E.T

return self.__class__(A_new, B_new, C_new, D_new, E_new)

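In the hierarchical structure the diagonal block `C` is stored as a vector, so applying it amounts to scaling rows, which the diff expresses as `einsum("i,ij->ij", ...)` rather than materializing `diag(C)`. A minimal sketch of that equivalence (names illustrative):

import torch
from torch import einsum

C = torch.randn(5)     # diagonal entries, stored as a vector
M = torch.randn(5, 3)  # matrix the diagonal block acts on

# scale the i-th row of M by C[i] without building the dense diagonal
row_scaled = einsum("i,ij->ij", C, M)
print(torch.allclose(row_scaled, torch.diag(C) @ M))  # True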
11 changes: 3 additions & 8 deletions singd/structures/triltoeplitz.py
@@ -6,15 +6,10 @@

import torch
from torch import Tensor, arange, cat, zeros
from torch.nn.functional import pad
from torch.nn.functional import conv1d, pad

from singd.structures.base import StructuredMatrix
from singd.structures.utils import (
all_traces,
lowest_precision,
supported_conv1d,
toeplitz_matmul,
)
from singd.structures.utils import all_traces, lowest_precision, toeplitz_matmul


class TrilToeplitzMatrix(StructuredMatrix):
@@ -149,7 +144,7 @@ def __matmul__(
# need to create fake channel dimensions
conv_input = pad(other._lower_diags, (dim - 1, 0)).unsqueeze(0)
conv_weight = col.flip(0).unsqueeze(0).unsqueeze(0)
mat_column = supported_conv1d(conv_input, conv_weight).squeeze(0)
mat_column = conv1d(conv_input, conv_weight).squeeze(0)
return TrilToeplitzMatrix(mat_column)

def rmatmat(self, mat: Tensor) -> Tensor:
11 changes: 3 additions & 8 deletions singd/structures/triutoeplitz.py
@@ -6,15 +6,10 @@

import torch
from torch import Tensor, arange, cat, triu_indices, zeros
from torch.nn.functional import pad
from torch.nn.functional import conv1d, pad

from singd.structures.base import StructuredMatrix
from singd.structures.utils import (
all_traces,
lowest_precision,
supported_conv1d,
toeplitz_matmul,
)
from singd.structures.utils import all_traces, lowest_precision, toeplitz_matmul


class TriuToeplitzMatrix(StructuredMatrix):
@@ -147,7 +142,7 @@ def __matmul__(
# need to create fake channel dimensions
conv_input = pad(other._upper_diags, (dim - 1, 0)).unsqueeze(0)
conv_weight = row.flip(0).unsqueeze(0).unsqueeze(0)
mat_row = supported_conv1d(conv_input, conv_weight).squeeze(0)
mat_row = conv1d(conv_input, conv_weight).squeeze(0)
return TriuToeplitzMatrix(mat_row)

def rmatmat(self, mat: Tensor) -> Tensor:
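In both Toeplitz classes the `supported_conv1d` wrapper gives way to `torch.nn.functional.conv1d`; the underlying trick is unchanged. The product of two lower-triangular Toeplitz matrices is again lower-triangular Toeplitz, and its first column is the causal convolution of the factors' first columns, which a padded `conv1d` computes; the upper-triangular case in `triutoeplitz.py` works the same way on the first row. A self-contained sketch that checks this for the lower-triangular case (the `tril_toeplitz` helper is illustrative, not part of the library):

import torch
from torch.nn.functional import conv1d, pad

def tril_toeplitz(col):
    # assemble a dense lower-triangular Toeplitz matrix from its first column
    dim = col.numel()
    idx = torch.arange(dim)
    offsets = idx.unsqueeze(1) - idx.unsqueeze(0)  # row index minus column index
    return torch.where(offsets >= 0, col[offsets.clamp(min=0)], torch.zeros(()))

a = torch.randn(5)  # first column of the left factor
b = torch.randn(5)  # first column of the right factor
dim = a.numel()

# first column of the product via a padded 1D convolution, as in the diff
conv_input = pad(b, (dim - 1, 0)).reshape(1, 1, -1)  # (batch, channel, length)
conv_weight = a.flip(0).reshape(1, 1, -1)            # (out channels, in channels, kernel)
col_product = conv1d(conv_input, conv_weight).reshape(-1)

dense_product = tril_toeplitz(a) @ tril_toeplitz(b)
print(torch.allclose(tril_toeplitz(col_product), dense_product, atol=1e-5))  # True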