-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathmask.py
31 lines (25 loc) · 987 Bytes
/
mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from src import register
from src.typing import LossData
# Public API of this module: only the mask regularizer is exported.
__all__ = ['MaskReg']
@register('disp_mask')
class MaskReg(nn.Module):
    """Regularizer for the SfM-Learner explainability mask.

    Reference: SfM-Learner (https://arxiv.org/abs/1704.07813).

    The explainability mask assigns a per-pixel weight to the photometric loss. Without a counter-term, the trivial
    way to shrink the photometric loss is to drive every weight to zero, i.e. ignore all pixels. This module supplies
    that counter-term: binary cross-entropy against an all-ones target, which pulls every mask value towards 1.
    """

    def forward(self, x: Tensor) -> LossData:
        """Compute the explainability-mask regularization loss.

        :param x: (Tensor) (*) Explainability mask that has already been passed through a sigmoid. Binary
            cross-entropy requires every value to lie in [0, 1].
        :return: Tuple of
            loss: (Tensor) (,) Mean binary cross-entropy of `x` against an all-ones target.
            loss_dict: (TensorDict) {} Always empty; no auxiliary outputs are produced.
        """
        # With a target of 1 everywhere, BCE reduces to -log(x) (PyTorch clamps the log internally),
        # so small mask values are penalised and values near 1 cost almost nothing.
        # NOTE(review): F.binary_cross_entropy is listed as unsafe under CUDA autocast; if this loss is ever run
        # with mixed precision it may need an fp32, autocast-disabled region -- confirm against the training setup.
        all_ones = torch.ones_like(x)
        bce = F.binary_cross_entropy(x, all_ones, reduction='mean')
        no_extras = {}
        return bce, no_extras