-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathquantile_loss.py
executable file
·130 lines (109 loc) · 5.06 KB
/
quantile_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
class QuantileLossCalculator():
    """Computes the combined quantile (pinball) loss for prespecified quantiles.

    Attributes:
        quantiles: Quantiles to compute losses for; each must lie in [0, 1].
        output_size: Stored as given. ``apply`` does not read it: it takes a
            single column per quantile from the last axis (``x[..., i]``),
            which corresponds to an output size of 1.
    """

    def __init__(self, quantiles, output_size):
        """Initializes the calculator with quantiles for loss calculations.

        Args:
            quantiles: Quantiles to use for computations.
            output_size: Stored on the instance; not used by ``apply``.
        """
        self.quantiles = quantiles
        self.output_size = output_size

    def quantile_loss(self, y, y_pred, quantile):
        """Computes the element-wise quantile loss for pytorch.

        Standard quantile loss as defined in the "Training Procedure" section
        of the main TFT paper:
            q * max(y - y_pred, 0) + (1 - q) * max(y_pred - y, 0)

        Args:
            y: Targets.
            y_pred: Predictions.
            quantile: Quantile to use for loss calculations (between 0 & 1).

        Returns:
            Tensor of element-wise losses with a singleton axis inserted at
            dim 1 (so per-quantile results can be concatenated on that axis).

        Raises:
            ValueError: If ``quantile`` is outside [0, 1].
        """
        if quantile < 0 or quantile > 1:
            raise ValueError(
                'Illegal quantile value={}! Values should be between 0 and 1.'.format(quantile))

        prediction_underflow = y - y_pred
        zeros = torch.zeros_like(prediction_underflow)
        # Under-prediction is weighted by q, over-prediction by (1 - q).
        q_loss = (quantile * torch.max(prediction_underflow, zeros)
                  + (1. - quantile) * torch.max(-prediction_underflow, zeros))
        return q_loss.unsqueeze(1)

    def apply(self, b, a):
        """Returns the quantile loss summed over quantiles, averaged elsewhere.

        For the quantile at position ``i`` of ``self.quantiles`` the loss is
        computed between ``a[..., i]`` and ``b[..., i]``. The per-quantile
        losses are summed, and the result is averaged over all remaining
        positions.

        Args:
            b: Predictions; the last axis holds one column per quantile.
            a: Targets; indexed exactly like ``b``, so its last axis must also
                have at least ``len(self.quantiles)`` entries. NOTE(review):
                this presumes callers tile the targets once per quantile, as
                the original TFT training code does — confirm.

        Returns:
            Scalar tensor with the loss.

        Raises:
            ValueError: If no quantiles are configured, or a quantile lies
                outside [0, 1].
        """
        quantiles = list(self.quantiles)
        if not quantiles:
            raise ValueError(
                'QuantileLossCalculator requires at least one quantile.')

        per_quantile = [
            self.quantile_loss(a[..., i], b[..., i], quantile)
            for i, quantile in enumerate(quantiles)
        ]
        # quantile_loss inserts a singleton axis at dim 1, so concatenating on
        # dim 1 stacks the quantiles there; sum over them, then average.
        return torch.mean(torch.sum(torch.cat(per_quantile, dim=1), dim=1))
class NormalizedQuantileLossCalculator():
    """Computes the normalised quantile loss (q-risk) for a single quantile.

    Attributes:
        quantiles: Quantiles supplied at construction. Not read by ``apply``,
            which receives the quantile as an argument.
        output_size: Stored as given; not read by ``apply``.
    """

    def __init__(self, quantiles, output_size):
        """Initializes the calculator.

        Args:
            quantiles: Quantiles to use for computations.
            output_size: Stored on the instance; not used by ``apply``.
        """
        self.quantiles = quantiles
        self.output_size = output_size

    def apply(self, y, y_pred, quantile):
        """Computes the normalised quantile loss for pytorch.

        Defined as 2 * mean(pinball loss) / mean(|y|). The pinball loss is the
        standard quantile loss from the "Training Procedure" section of the
        main TFT paper.

        Args:
            y: Targets.
            y_pred: Predictions.
            quantile: Quantile to use for loss calculations (between 0 & 1).

        Returns:
            Scalar tensor with the normalised loss. The division is not
            guarded: if mean(|y|) is zero the result is inf or nan.

        Raises:
            ValueError: If ``quantile`` is outside [0, 1].
        """
        if quantile < 0 or quantile > 1:
            raise ValueError(
                'Illegal quantile value={}! Values should be between 0 and 1.'.format(quantile))

        prediction_underflow = y - y_pred
        zeros = torch.zeros_like(prediction_underflow)
        weighted_errors = (quantile * torch.max(prediction_underflow, zeros)
                           + (1. - quantile) * torch.max(-prediction_underflow, zeros))
        quantile_loss = torch.mean(weighted_errors)
        # Normalise by the mean absolute *target*. Dividing the loss by its own
        # magnitude (the previous behaviour) always yielded 2, or nan at 0.
        # Cast so integer targets do not break torch.mean.
        normaliser = torch.mean(torch.abs(y).to(quantile_loss.dtype))
        return 2 * quantile_loss / normaliser