-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathFiniteDifferenceOp.py
122 lines (104 loc) · 3.52 KB
/
FiniteDifferenceOp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Class for Finite Difference Operator."""
from collections.abc import Sequence
from typing import Literal
import torch
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.utils.filters import filter_separable
class FiniteDifferenceOp(LinearOperator):
    """Finite Difference Operator.

    Computes finite differences of a tensor along one or more dimensions via
    separable filtering, stacking the per-dimension results along a new first
    dimension.
    """

    @staticmethod
    def finite_difference_kernel(mode: str) -> torch.Tensor:
        """Return the 1D finite difference kernel for the given mode.

        Parameters
        ----------
        mode
            String specifying kernel type, one of 'central', 'forward' or 'backward'

        Returns
        -------
            Finite difference kernel of length 3

        Raises
        ------
        ValueError
            If mode is not one of 'central', 'forward' or 'backward'
        """
        # Central differences are scaled by 1/2 (two-sample spacing); this also
        # makes the kernel float, while forward/backward kernels stay integer.
        if mode == 'central':
            kernel = torch.tensor((-1, 0, 1)) / 2
        elif mode == 'forward':
            kernel = torch.tensor((0, -1, 1))
        elif mode == 'backward':
            kernel = torch.tensor((-1, 1, 0))
        else:
            raise ValueError(f'mode should be one of (central, forward, backward), not {mode}')
        return kernel

    def __init__(
        self,
        dim: Sequence[int],
        mode: Literal['central', 'forward', 'backward'] = 'central',
        pad_mode: Literal['zeros', 'circular'] = 'zeros',
    ) -> None:
        """Finite difference operator.

        Parameters
        ----------
        dim
            Dimensions along which finite differences are calculated
        mode
            Type of finite difference operator
        pad_mode
            Padding to ensure output has the same size as the input
        """
        super().__init__()
        self.dim = dim
        # Map the public 'zeros' option onto torch's 'constant' padding name.
        self.pad_mode: Literal['constant', 'circular'] = 'constant' if pad_mode == 'zeros' else pad_mode
        # Register as buffer so the kernel moves with the module (.to/.cuda) and
        # is included in state_dict without being a trainable parameter.
        self.register_buffer('kernel', self.finite_difference_kernel(mode))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
        """Forward of finite differences.

        Parameters
        ----------
        x
            Input tensor

        Returns
        -------
            Finite differences of x along dim stacked along first dimension
        """
        return (
            torch.stack(
                [
                    filter_separable(x, (self.kernel,), dim=(dim,), pad_mode=self.pad_mode, pad_value=0.0)
                    for dim in self.dim
                ]
            ),
        )

    def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
        """Adjoint of finite differences.

        Parameters
        ----------
        y
            Finite differences stacked along first dimension

        Returns
        -------
            Adjoint finite differences

        Raises
        ------
        ValueError
            If the first dimension of y is not equal to the number of dimensions along which the finite
            differences are calculated
        """
        if y.shape[0] != len(self.dim):
            raise ValueError('First dimension of input tensor has to match the number of finite difference directions.')
        # The adjoint of correlation with the kernel is correlation with the
        # flipped kernel; results per direction are summed.
        return (
            torch.sum(
                torch.stack(
                    [
                        filter_separable(
                            yi,
                            (torch.flip(self.kernel, dims=(-1,)),),
                            dim=(dim,),
                            pad_mode=self.pad_mode,
                            pad_value=0.0,
                        )
                        # strict=True is safe: lengths were validated above.
                        for dim, yi in zip(self.dim, y, strict=True)
                    ]
                ),
                dim=0,
            ),
        )