-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprediction_module.py
62 lines (48 loc) · 2.69 KB
/
prediction_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
class PredictionModule(nn.Module):
    """Optionally impute masked data, then classify it.

    When both an imputation callable and a mask are available, the data is
    first passed through ``imputation`` and the imputed result is fed to
    ``classifier``. Otherwise the raw data is classified directly and a zero
    reconstruction loss is returned.

    Parameters:
    -----------
    classifier : callable
        Applied to the (possibly imputed) data to produce the prediction.
    imputation : callable or None
        Called as ``imputation(data, mask, index)``; must return a pair
        ``(x_imputed, loss_reconstruction)``. If None, no imputation is done.
    imputation_reg : any
        Accepted for interface compatibility; currently unused.
    input_size : any
        Stored as ``self.input_size``; not used inside this class.
    **kwargs
        Accepted and ignored.
    """

    def __init__(self, classifier, imputation=None, imputation_reg=None, input_size=None, **kwargs):
        super().__init__()
        self.classifier = classifier
        self.imputation = imputation
        # NOTE(review): imputation_reg and **kwargs are accepted but not stored.
        # Storing an nn.Module here would register a new submodule and change
        # the state_dict keys, breaking existing checkpoints, so it is left out.
        self.input_size = input_size

    def get_imputation(self, data, mask, index=None):
        """Get the imputation of the data given the mask.

        Parameters:
        -----------
        data : torch.Tensor of shape (nb_sample_z_MC * nb_sample_z_iwae * batch_size, channels, size_lists...)
            The data to be imputed.
        mask : torch.Tensor of shape (nb_sample_z_MC * nb_sample_z_iwae * batch_size, channels, size_lists...) or None
            The mask to be used for the imputation; should have the same
            number of elements as the data.
        index : torch.Tensor of shape (nb_sample_z_MC * nb_sample_z_iwae * batch_size, ) or None
            The index to be used for imputation, forwarded to ``imputation``.

        Returns:
        --------
        x_imputed : torch.Tensor
            The imputed data. If there is no imputation callable or no mask,
            the data is returned unchanged (the same fallback ``forward`` uses).
        """
        if self.imputation is None or mask is None:
            # Nothing to impute with: mirror forward(), which falls back to raw data.
            return data
        # Quick fix when the data layout does not match the mask
        # (changing the dataset might be better). TODO
        data = data.reshape(mask.shape)
        x_imputed, _ = self.imputation(data, mask, index)
        return x_imputed

    def forward(self, data, mask=None, index=None):
        """Using the data and the mask, do the imputation and classification.

        Defined as ``forward`` (not ``__call__``) so that calling the module
        goes through ``nn.Module.__call__`` and registered hooks fire.

        Parameters:
        -----------
        data : torch.Tensor of shape (nb_sample_z_MC * nb_sample_z_iwae * batch_size, channels, size_lists...)
            The data to be classified.
        mask : torch.Tensor of shape (nb_sample_z_MC * nb_sample_z_iwae * batch_size, channels, size_lists...) or None
            The mask to be used for the classification; it is reshaped to
            ``data.shape``, so it must have the same number of elements.
        index : torch.Tensor of shape (nb_sample_z_MC * nb_sample_z_iwae * batch_size, ) or None
            The index to be used for imputation.

        Returns:
        --------
        y_hat : torch.Tensor of shape (nb_imputation * nb_sample_z_MC * nb_sample_z_iwae * batch_size, nb_category)
            The output of the classification.
        loss_reconstruction : torch.Tensor of shape (1)
            Regularization term that can be added to the loss (for instance
            with an autoencoder-regularized imputation). Zero when no
            imputation is performed.
        """
        if mask is not None:
            mask = mask.reshape(data.shape)
        if self.imputation is not None and mask is not None:
            x_imputed, loss_reconstruction = self.imputation(data, mask, index)
            y_hat = self.classifier(x_imputed)
        else:
            y_hat = self.classifier(data)
            loss_reconstruction = torch.zeros((1), device=data.device)
        return y_hat, loss_reconstruction