NanMixture: Distribution to model missing values #913

Merged: 27 commits merged on Jul 28, 2020

Changes from 21 commits

Commits (27)
87a75fb
added a deterministic/degenerate distribution
PascalIversen Jun 24, 2020
19ae387
Corrected the formula of the Standard Deviation of the Mixture Distri…
PascalIversen Jun 24, 2020
54eddd3
Added a distribution to model missing values
PascalIversen Jun 28, 2020
fe69afa
added test script to test log_prob and the gradients and fixed some e…
PascalIversen Jul 1, 2020
880e78f
fixed bug in test script
PascalIversen Jul 1, 2020
1b53501
Merge branch 'master' into distribution_missing_values
PascalIversen Jul 6, 2020
72ff878
corrected true gradients in the test file and edge cases of the log_p…
PascalIversen Jul 6, 2020
188e136
fixed edge cases of the gradients
PascalIversen Jul 6, 2020
1f99be3
bugfix CategoricalOutput
PascalIversen Jul 8, 2020
9ede612
test skip
PascalIversen Jul 8, 2020
953c255
style fix
PascalIversen Jul 8, 2020
438d275
addressed PR issues
PascalIversen Jul 13, 2020
48bed05
skipping tests which take too long and rearranging imports
PascalIversen Jul 13, 2020
576c392
fixed the output args issue of a NanMixture with a Categorical distri…
PascalIversen Jul 13, 2020
283328b
lowered assertion tolerances
PascalIversen Jul 20, 2020
aec69ef
lowered assertion tolerances
PascalIversen Jul 20, 2020
03a4d0f
refactoring of the NanMixture tests
PascalIversen Jul 20, 2020
6ba3eed
refactoring of the NanMixture tests
PascalIversen Jul 20, 2020
205bc58
Merge branch 'master' into distribution_missing_values
PascalIversen Jul 21, 2020
1b557bb
added NanMixture support to the SimpleFeedForward and fixed typing issue
PascalIversen Jul 21, 2020
77578e7
refactoring
PascalIversen Jul 21, 2020
52f5f45
removing SimpleFeedForward changes
PascalIversen Jul 23, 2020
e6ec622
Merge branch 'master' into distribution_missing_values
PascalIversen Jul 27, 2020
100921c
increased sample size for mixture stddev and mean tests to prevent fa…
PascalIversen Jul 27, 2020
8de41cb
added random seeds to tests
PascalIversen Jul 27, 2020
c3f82a7
increased tol
PascalIversen Jul 27, 2020
0f57265
Merge branch 'master' into distribution_missing_values
lostella Jul 28, 2020
27 changes: 18 additions & 9 deletions src/gluonts/model/simple_feedforward/_network.py
@@ -23,7 +23,7 @@

# First-party imports
from gluonts.mx.block.scaler import MeanScaler, NOPScaler
from gluonts.mx.distribution import DistributionOutput
from gluonts.mx.distribution import DistributionOutput, NanMixture
from gluonts.support.util import weighted_average


@@ -158,15 +158,24 @@ def hybrid_forward(
distr_args, loc=loc, scale=scale
)

# (batch_size, prediction_length, target_dim)
loss = distr.loss(future_target)
if isinstance(distr, NanMixture):
# demasking the missing values for the future_targets

weighted_loss = weighted_average(
F=F, x=loss, weights=future_observed_values, axis=1
)

# (batch_size, )
return weighted_loss
loss = distr.loss(
F.where(
future_observed_values,
future_target,
0.0 / future_target.zeros_like(),
)
)
return loss
else:
loss = distr.loss(future_target)
weighted_loss = weighted_average(
F=F, x=loss, weights=future_observed_values, axis=1
)
# (batch_size, )
return weighted_loss


class SimpleFeedForwardSamplingNetwork(SimpleFeedForwardNetworkBase):
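The NanMixture branch above relies on `0.0 / future_target.zeros_like()` evaluating to NaN elementwise, since a literal NaN constant is not available when `F` is the symbolic (`mx.sym`) API. A minimal imperative sketch of the masking step (hypothetical helper, not part of the PR):

import mxnet as mx


def mask_unobserved_as_nan(F, future_target, future_observed_values):
    # Keep observed entries; replace unobserved entries with NaN so that a
    # NanMixture can assign them explicit probability mass instead of
    # weighting them out of the loss.
    nan = 0.0 / future_target.zeros_like()  # elementwise 0/0 -> NaN
    return F.where(future_observed_values, future_target, nan)


target = mx.nd.array([[1.0, 2.0, 3.0]])
observed = mx.nd.array([[1.0, 0.0, 1.0]])
print(mask_unobserved_as_nan(mx.nd, target, observed))
# [[ 1. nan  3.]]

Note that in the NanMixture branch the loss is returned without `weighted_average`, because the unobserved entries are modelled explicitly rather than being masked out of the average.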
6 changes: 6 additions & 0 deletions src/gluonts/mx/distribution/__init__.py
@@ -20,6 +20,7 @@
InverseBoxCoxTransformOutput,
)
from .categorical import Categorical, CategoricalOutput
from .deterministic import Deterministic, DeterministicOutput
from .dirichlet import Dirichlet, DirichletOutput
from .dirichlet_multinomial import (
DirichletMultinomial,
@@ -40,6 +41,7 @@
MultivariateGaussian,
MultivariateGaussianOutput,
)
from .nan_mixture import NanMixture, NanMixtureOutput
from .neg_binomial import NegativeBinomial, NegativeBinomialOutput
from .piecewise_linear import (
PiecewiseLinear,
@@ -71,6 +73,8 @@
"LowrankMultivariateGaussianOutput",
"MixtureDistributionOutput",
"MixtureDistribution",
"NanMixture",
"NanMixtureOutput",
"NegativeBinomialOutput",
"NegativeBinomial",
"UniformOutput",
@@ -95,6 +99,8 @@
"CategoricalOutput",
"LogitNormal",
"LogitNormalOutput",
"Deterministic",
"DeterministicOutput",
]

# fix Sphinx issues, see https://bit.ly/2K2eptM
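With these re-exports, the new classes are importable from the top-level `gluonts.mx.distribution` package. A small usage sketch (the `NanMixtureOutput` constructor is assumed here to wrap a base `DistributionOutput`; `nan_mixture.py` itself is not part of this diff):

from gluonts.mx.distribution import (
    DeterministicOutput,
    GaussianOutput,
    NanMixtureOutput,
)

# Assumption: NanMixtureOutput wraps a base DistributionOutput so that an
# estimator can emit a distribution with explicit probability mass on NaN.
distr_output = NanMixtureOutput(GaussianOutput())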
2 changes: 1 addition & 1 deletion src/gluonts/mx/distribution/categorical.py
@@ -91,7 +91,7 @@ def s(bin_probs):
return indices

return _sample_multiple(s, self.probs, num_samples=num_samples).astype(
"int32"
dtype
)

@property
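The change above makes `Categorical.sample` respect its `dtype` argument instead of always casting samples to `int32`. A sketch of the effect (constructor arguments are assumed for illustration; only the `sample` body appears in this diff):

import mxnet as mx
import numpy as np
from gluonts.mx.distribution import Categorical

# Constructor arguments assumed; the diff only shows the sample() body.
cat = Categorical(log_probs=mx.nd.log(mx.nd.array([[0.2, 0.3, 0.5]])))

int_samples = cat.sample(num_samples=4)                      # default dtype
float_samples = cat.sample(num_samples=4, dtype=np.float32)  # honoured after this fix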
152 changes: 152 additions & 0 deletions src/gluonts/mx/distribution/deterministic.py
@@ -0,0 +1,152 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
import math
from functools import partial
from typing import Dict, List, Optional, Tuple

# Third-party imports
import numpy as np

# First-party imports
from gluonts.model.common import Tensor
from gluonts.core.component import validated
from gluonts.support.util import erf, erfinv

# Relative imports
from .distribution import Distribution, _sample_multiple, getF, softplus
from .distribution_output import DistributionOutput


class Deterministic(Distribution):
r"""
Deterministic/Degenerate distribution.

Parameters
----------
value
Tensor containing the values, of shape `(*batch_shape, *event_shape)`.
F
"""

is_reparameterizable = True

@validated()
def __init__(self, value: Tensor) -> None:
self.value = value
self.F = getF(value)

@property
def batch_shape(self) -> Tuple:
return self.value.shape

@property
def event_shape(self) -> Tuple:
return ()

@property
def event_dim(self) -> int:
return 0

def log_prob(self, x: Tensor) -> Tensor:
F = self.F
value = self.value
is_both_nan = F.broadcast_logical_and(x != x, value != value)
is_equal_or_both_nan = F.broadcast_logical_or(
(x == value), is_both_nan
)
return F.log(is_equal_or_both_nan)

@property
def mean(self) -> Tensor:
return self.value

@property
def stddev(self) -> Tensor:
return self.value.zeros_like()

def cdf(self, x):
F = self.F
value = self.value
is_both_nan = F.broadcast_logical_and(
F.contrib.isnan(x), F.contrib.isnan(value)
)
is_greater_equal_or_both_nan = F.broadcast_logical_or(
(x >= value), is_both_nan
)
return is_greater_equal_or_both_nan

def sample(
self, num_samples: Optional[int] = None, dtype=np.int32
) -> Tensor:
return _sample_multiple(
lambda value: value, value=self.value, num_samples=num_samples
).astype(dtype=dtype)

def quantile(self, level: Tensor) -> Tensor:
Review comment (Contributor): In a deterministic distribution, shouldn't all quantiles be the same, i.e., the value of the distribution?

Reply (Contributor Author): Mostly, but there are some edge cases: the quantile at p = 0, for example, is -inf. Furthermore, if the deterministic distribution is NaN-valued, the quantile is always NaN.

F = self.F
# we consider level to be an independent axis and so expand it
# to shape (num_levels, 1, 1, ...)

for _ in range(self.all_dim):
level = level.expand_dims(axis=-1)

quantiles = F.broadcast_mul(self.value, level.ones_like())
level = F.broadcast_mul(quantiles.ones_like(), level)

minus_inf = -quantiles.ones_like() / 0.0
quantiles = F.where(
F.broadcast_logical_or(level != 0, F.contrib.isnan(quantiles)),
quantiles,
minus_inf,
)

nans = level.zeros_like() / 0.0
quantiles = F.where(level != level, nans, quantiles)

return quantiles

@property
def args(self) -> List:
return [self.value]


class DeterministicOutput(DistributionOutput):
args_dim: Dict[str, int] = {"value": 1}
distr_cls: type = Deterministic

@classmethod
def domain_map(cls, F, value):
r"""
Maps raw tensors to the single argument needed to construct a
Deterministic distribution.

Parameters
----------
F
value
Tensor of shape `(*batch_shape, 1)`

Returns
-------
Tensor
A squeezed tensor of shape `(*batch_shape)` containing the
distribution's value.
"""
return value.squeeze(axis=-1)

@property
def event_shape(self) -> Tuple:
return ()
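A short imperative check of the behaviour discussed in the review comments above, under the implementation shown (expected outputs are annotated as comments, assuming the revision in this diff):

import mxnet as mx
from gluonts.mx.distribution import Deterministic

det = Deterministic(value=mx.nd.array([2.0]))

print(det.mean.asnumpy())    # [2.]
print(det.stddev.asnumpy())  # [0.]

# log_prob is log(1) = 0 where x equals the value, log(0) = -inf elsewhere
print(det.log_prob(mx.nd.array([2.0])).asnumpy())  # [0.]
print(det.log_prob(mx.nd.array([3.0])).asnumpy())  # [-inf]

# All quantiles equal the value, except level 0, which maps to -inf
levels = mx.nd.array([0.0, 0.5, 1.0])
print(det.quantile(levels).asnumpy())  # [[-inf], [2.], [2.]]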
2 changes: 1 addition & 1 deletion src/gluonts/mx/distribution/gaussian.py
@@ -19,11 +19,11 @@
# Third-party imports
import numpy as np

from gluonts.core.component import validated

# First-party imports
from gluonts.model.common import Tensor
from gluonts.support.util import erf, erfinv
from gluonts.core.component import validated

# Relative imports
from .distribution import Distribution, _sample_multiple, getF, softplus