-
Notifications
You must be signed in to change notification settings - Fork 3
/
ProximableFunctionalSeparableSum.py
118 lines (95 loc) · 3.92 KB
/
ProximableFunctionalSeparableSum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Separable Sum of Proximable Functionals."""
from __future__ import annotations
import operator
from collections.abc import Iterator
from functools import reduce
from typing import cast
import torch
from typing_extensions import Self, Unpack
from mrpro.operators.Functional import ProximableFunctional
from mrpro.operators.Operator import Operator
class ProximableFunctionalSeparableSum(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor]]):
    r"""Separable Sum of Proximable Functionals.

    This is a separable sum of the functionals. The forward method returns the sum of the functionals
    evaluated at the inputs, :math:`\sum_i f_i(x_i)`.
    """

    # The individual functionals f_i; input i of forward/prox is routed to functionals[i].
    functionals: tuple[ProximableFunctional, ...]

    def __init__(self, *functionals: ProximableFunctional) -> None:
        """Initialize the separable sum of proximable functionals.

        Parameters
        ----------
        functionals
            The proximable functionals to be summed.
        """
        super().__init__()
        self.functionals = functionals

    def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor]:
        """Apply the functionals to the inputs.

        Parameters
        ----------
        x
            The inputs to the functionals. Must contain exactly one tensor per functional.

        Returns
        -------
            The sum of the functionals applied to the inputs, :math:`\\sum_i f_i(x_i)`.

        Raises
        ------
        ValueError
            If the number of inputs does not match the number of functionals.
        """
        if len(x) != len(self.functionals):
            raise ValueError('The number of inputs must match the number of functionals.')
        # Each functional returns a 1-tuple; sum the scalar results over all (f_i, x_i) pairs.
        result = reduce(operator.add, (f(xi)[0] for f, xi in zip(self.functionals, x, strict=True)))
        return (result,)

    def prox(self, *x: torch.Tensor, sigma: float | torch.Tensor = 1) -> tuple[torch.Tensor, ...]:
        """Apply the proximal operators of the functionals to the inputs.

        The prox of a separable sum is separable: it is the tuple of the individual
        proximal operators applied element-wise, :math:`(\\mathrm{prox}_{\\sigma f_i}(x_i))_i`.

        Parameters
        ----------
        x
            The inputs to the proximal operators
        sigma
            The scaling factor for the proximal operators

        Returns
        -------
            A tuple of the proximal operators applied to the inputs
        """
        # *x already has type tuple[torch.Tensor, ...]; no cast needed.
        prox_x = tuple(f.prox(xi, sigma)[0] for f, xi in zip(self.functionals, x, strict=True))
        return prox_x

    def prox_convex_conj(self, *x: torch.Tensor, sigma: float | torch.Tensor = 1) -> tuple[torch.Tensor, ...]:
        """Apply the proximal operators of the convex conjugate of the functionals to the inputs.

        The convex conjugate of a separable sum is the separable sum of the conjugates,
        so its prox is again applied element-wise.

        Parameters
        ----------
        x
            The inputs to the proximal operators
        sigma
            The scaling factor for the proximal operators

        Returns
        -------
            A tuple of the proximal convex conjugate operators applied to the inputs
        """
        prox_convex_conj_x = tuple(
            f.prox_convex_conj(xi, sigma)[0] for f, xi in zip(self.functionals, x, strict=True)
        )
        return prox_convex_conj_x

    def __or__(
        self,
        other: ProximableFunctional | ProximableFunctionalSeparableSum,
    ) -> Self:
        """Create a new separable sum with the functionals of ``other`` appended."""
        if isinstance(other, ProximableFunctionalSeparableSum):
            return self.__class__(*self.functionals, *other.functionals)
        elif isinstance(other, ProximableFunctional):
            return self.__class__(*self.functionals, other)
        else:
            # Let Python try other.__ror__ for unsupported operand types.
            return NotImplemented  # type: ignore[unreachable]

    def __ror__(self, other: ProximableFunctional) -> Self:
        """Create a new separable sum with ``other`` prepended to the functionals."""
        if isinstance(other, ProximableFunctional):
            return self.__class__(other, *self.functionals)
        else:
            return NotImplemented  # type: ignore[unreachable]

    def __iter__(self) -> Iterator[ProximableFunctional]:
        """Iterate over the functionals."""
        return iter(self.functionals)

    def __len__(self) -> int:
        """Return the number of functionals."""
        return len(self.functionals)