-
Notifications
You must be signed in to change notification settings - Fork 2
/
huber_func.py
78 lines (56 loc) · 2.27 KB
/
huber_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#
"""Implementation of the Huber norm """
# Imports for common Python 2/3 codebase
from __future__ import print_function, division, absolute_import
from future import standard_library
standard_library.install_aliases()
from builtins import super
import numpy as np
from odl.solvers.functional.functional import Functional
from odl.operator import Operator
class HuberNorm(Functional):

    """The Huber functional, a smooth approximation of the L1 norm.

    For a parameter ``epsilon > 0`` the functional applies pointwise::

        h(t) = t**2 / (2 * epsilon)     if |t| < epsilon
        h(t) = |t| - epsilon / 2        otherwise

    and accumulates the result through the domain's inner product with
    ``domain.one()``. The two branches and their derivatives agree at
    ``|t| == epsilon``, so the functional is differentiable with gradient
    ``x / epsilon`` on the quadratic part and ``sign(x)`` on the linear part.
    """

    def __init__(self, space, epsilon):
        """Initialize a new instance.

        Parameters
        ----------
        space : `DiscreteLp` or `FnBase`
            Domain of the functional.
        epsilon : positive float
            The parameter of the Huber functional: the functional is
            quadratic where ``|x| < epsilon`` and linear elsewhere.

        Raises
        ------
        ValueError
            If ``epsilon`` is not strictly positive (this includes NaN).
        """
        epsilon = float(epsilon)
        # ``not epsilon > 0`` also rejects NaN. With epsilon <= 0 the
        # quadratic branch divides by zero or flips sign, so the result
        # would not be a Huber function.
        if not epsilon > 0:
            raise ValueError('`epsilon` must be positive, got {}'
                             ''.format(epsilon))
        self.__epsilon = epsilon
        # The second derivative of h is 1/epsilon on the quadratic part and
        # 0 on the linear part, so the gradient is Lipschitz with constant
        # 1/epsilon.
        super().__init__(space=space, linear=False,
                         grad_lipschitz=1.0 / epsilon)

    @property
    def epsilon(self):
        """The parameter of the Huber functional."""
        return self.__epsilon

    def _call(self, x):
        """Return the Huber functional evaluated at ``x``."""
        abs_x = x.ufuncs.absolute()
        # Mask: 1.0 where the quadratic branch applies, 0.0 where the linear
        # branch applies. 0 and 1 are exact in float32.
        quad = np.float32(abs_x.asarray() < self.epsilon)
        pointwise = ((x * quad) ** 2 / (2.0 * self.epsilon) +
                     (abs_x - self.epsilon / 2.0) * (1 - quad))
        return pointwise.inner(self.domain.one())

    @property
    def gradient(self):
        """Gradient operator of the functional.

        The returned operator maps ``x`` to ``x / epsilon`` where
        ``|x| < epsilon`` and to ``sign(x)`` elsewhere.
        """
        functional = self

        class HuberNormGradient(Operator):

            """The gradient operator of this functional."""

            def __init__(self):
                """Initialize a new instance."""
                super().__init__(functional.domain, functional.domain,
                                 linear=False)

            # TODO: Update this call. Might not work for ProductSpaces,
            # presumably because it relies on ``x.asarray()`` -- confirm.
            def _call(self, x):
                """Apply the gradient operator to the given point."""
                quad = np.float32(
                    x.ufuncs.absolute().asarray() < functional.epsilon)
                return (x * quad / functional.epsilon +
                        x.ufuncs.sign() * (1 - quad))

        return HuberNormGradient()

    def __repr__(self):
        """Return ``repr(self)``."""
        return '{}({!r}, epsilon={!r})'.format(self.__class__.__name__,
                                               self.domain, self.epsilon)