-
Notifications
You must be signed in to change notification settings - Fork 2
/
flatten.py
28 lines (21 loc) · 830 Bytes
/
flatten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Curvature-vector products of batch-wise flatten operation."""
from backpack.core.layers import Flatten
from ..hbp.module import hbp_decorate
class CVPFlatten(hbp_decorate(Flatten)):
    """Flatten all dimensions except batch dimension, with CVP support.

    Curvature information is passed through this layer unchanged: see
    ``input_hessian``, which returns the output Hessian as-is.
    """

    @classmethod
    def from_torch(cls, torch_layer):
        """Create a ``CVPFlatten`` from an existing ``Flatten`` layer.

        Parameters
        ----------
        torch_layer : backpack.core.layers.Flatten
            Layer to convert. Only its type is checked; no state is copied
            from it.

        Returns
        -------
        CVPFlatten
            A freshly constructed instance (built with no arguments).

        Raises
        ------
        ValueError
            If ``torch_layer`` is not an instance of
            ``backpack.core.layers.Flatten``.
        """
        if not isinstance(torch_layer, Flatten):
            raise ValueError(
                "Expecting backpack.core.layers.Flatten, got {}".format(
                    torch_layer.__class__
                )
            )
        return cls()

    # override
    def hbp_hooks(self):
        """No hooks required.

        Overrides the base implementation so that nothing is registered;
        the method body is intentionally empty and returns ``None``.
        """

    # override
    def input_hessian(self, output_hessian, modify_2nd_order_terms="none"):
        """Pass on the Hessian with respect to the layer input.

        Parameters
        ----------
        output_hessian :
            Hessian with respect to the layer output. It is returned
            unmodified.
        modify_2nd_order_terms : str, optional
            Accepted for interface compatibility with other HBP modules;
            unused by this layer.

        Returns
        -------
        The same object as ``output_hessian``.
        """
        return output_hessian