
NanMixture: Distribution to model missing values #913

Merged: 27 commits merged into awslabs:master from PascalIversen:distribution_missing_values on Jul 28, 2020

Conversation

@PascalIversen (Contributor):

Description of changes:
Added a distribution to model data that contains missing values.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
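
Minimal usage, as exercised in the experiment further down in this thread (the NanMixture wraps a base distribution output and adds an explicit missing-value component):

from gluonts.mx.distribution import NanMixtureOutput, GaussianOutput

# wrap a Gaussian so that NaN-valued targets are modeled through an
# explicit missing-value probability
distr_output = NanMixtureOutput(GaussianOutput())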


@benidis (Contributor) left a comment:


Thanks Pascal, looks great overall. I see there are some TODOs here and there... are these things you are working on including now?

src/gluonts/mx/distribution/deterministic.py
lambda value: value, value=self.value, num_samples=num_samples
)

def quantile(self, level: Tensor) -> Tensor:
A reviewer (Contributor) commented:

In a deterministic distribution, shouldn't all quantiles be the same, i.e., the value of the distribution?

@PascalIversen (Author) replied:

Mostly, but there are some edge cases: the quantile at p=0, for example, is -inf. Furthermore, if the deterministic distribution is NaN-valued, the quantile is always NaN.
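
For illustration, a minimal NumPy sketch of these edge cases (not the PR's actual MXNet implementation):

import numpy as np

def deterministic_quantile(value, level):
    # point mass at `value`: F(x) = 0 for x < value and 1 for x >= value,
    # so the generalized inverse is `value` for every level in (0, 1],
    # but -inf at level 0; a NaN-valued point mass yields NaN everywhere
    q = np.where(level > 0.0, value, -np.inf)
    return np.where(np.isnan(value), np.nan, q)

deterministic_quantile(3.0, np.array([0.0, 0.5, 1.0]))  # [-inf, 3., 3.]
deterministic_quantile(np.nan, np.array([0.5]))         # [nan]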

src/gluonts/mx/distribution/nan_mixture.py
@PascalIversen (Author):

> Thanks Pascal, looks great overall. I see there are some TODOs here and there... are these things you are working on including now?

Thank you for reviewing, Konstantinos.
I forgot to remove some of these TODOs. The only thing left concerns the testing of the NanMixture: since we haven't decided whether the mean/stddev should return NaN or the values of the non-NaN component, I would leave it for now.
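
For concreteness, the two conventions in question, sketched in NumPy with hypothetical parameter values (mu is the mean of the non-NaN component, nan_prob the missing-value probability):

import numpy as np

mu, nan_prob = 2.0, 0.3  # hypothetical values

# option 1: any positive missing-value probability makes the mean undefined
mean_nan_aware = np.nan if nan_prob > 0 else mu  # -> nan here

# option 2: report the mean of the non-NaN component only
mean_component = mu                              # -> 2.0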

test/distribution/test_nan_mixture.py

# gradient-based training loop from the parameter-recovery test
loss_value = loss.mean().asnumpy()
t.set_postfix({"loss": loss_value})
trainer.step(BATCH_SIZE)
A reviewer (Contributor) commented:

This is just an idea (and we could do this in a separate PR): we know that EM works better than SGD for this job, so would it make sense to use EM in this type of test? As long as this involves invoking the loss (and its backward pass), it should be fine. In that case we could probably hope for “stricter” assertions on the recovered parameters below.

@PascalIversen (Author) replied on Jul 20, 2020:

We could do that for the Mixture distribution. However, for the NanMixture the maximisation step is not defined. Also, EM does not use the gradient. In a separate PR I could maybe implement something like "A Gradient Algorithm Locally Equivalent to the EM Algorithm", which does use the gradient and would, I think, work for the NanMixture.
We can also achieve stricter assertions by using a larger sample size, which however makes the test slower.
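
For context, the kind of update the suggestion implies for the plain Mixture case could look like this NumPy sketch of one EM iteration for a two-component Gaussian mixture (an illustration, not code from this PR):

import numpy as np

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def em_step(x, pi, mu1, sigma1, mu2, sigma2):
    # E-step: posterior responsibility of component 1 for each sample
    p1 = pi * gaussian_pdf(x, mu1, sigma1)
    p2 = (1 - pi) * gaussian_pdf(x, mu2, sigma2)
    r = p1 / (p1 + p2)

    # M-step: closed-form weighted maximum-likelihood updates
    pi = r.mean()
    mu1 = (r * x).sum() / r.sum()
    mu2 = ((1 - r) * x).sum() / (1 - r).sum()
    sigma1 = np.sqrt((r * (x - mu1) ** 2).sum() / r.sum())
    sigma2 = np.sqrt(((1 - r) * (x - mu2) ** 2).sum() / (1 - r).sum())
    return pi, mu1, sigma1, mu2, sigma2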

test/distribution/test_nan_mixture.py
@PascalIversen (Author):

I created some artificial time series data with missing values and fitted a Gaussian NanMixture using the SimpleFeedForward estimator.
I had to make one more adjustment to SimpleFeedForward's hybrid_forward(), and I am not sure whether it should be part of this PR. Namely, I have to demask the NaN values before the loss calculation:

    if isinstance(distr, NanMixture):
        # demask the future targets: re-insert NaN where values were not
        # observed, so the NanMixture's NaN component can model them;
        # 0.0 / zeros_like() builds a NaN tensor with the shape of
        # future_target, which also works in hybridized (symbolic) mode
        loss = distr.loss(
            F.where(
                future_observed_values,
                future_target,
                0.0 / future_target.zeros_like(),
            )
        )
        return loss
    else:
        # default path: mask the unobserved values out of the loss instead
        loss = distr.loss(future_target)
        weighted_loss = weighted_average(
            F=F, x=loss, weights=future_observed_values, axis=1
        )
        # (batch_size,)
        return weighted_loss

In my experiments, the parameters are recovered; the plot produced at the end of the script below compares the true and predicted means and missing-value probabilities.

This was done using the following code:


# Third-party imports
# %matplotlib inline  (notebook magic; the script was exported from a notebook)
import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from gluonts.dataset.repository.datasets import get_dataset, dataset_recipes
from gluonts.dataset.util import to_pandas
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.common import ListDataset



def nan_prob_gen():
    # cycle the missing-value probability from 0.01 up to 0.3 and back down,
    # over one 24-step period
    while True:
        for i in np.linspace(0.01, 0.3, 12):
            yield i
        for i in np.linspace(0.3, 0.01, 12):
            yield i




def create_dataset(num_series, num_steps, period=24, mu=1, sigma=0.3, true_mean=False):
    # create target: noise + pattern
    # noise
    noise = np.random.normal(mu, sigma, size=(num_series, num_steps))

    # pattern: a sinusoid over [-pi, pi], tiled to num_steps
    sin_minus_pi_pi = np.sin(
        np.tile(np.linspace(-np.pi, np.pi, period), int(num_steps / period))
    )

    target = np.tile(sin_minus_pi_pi.reshape(1, -1), (num_series, 1))

    if not true_mean:
        # inject NaNs with a time-varying probability, then add the noise
        nan_gen = nan_prob_gen()
        nan_prob = [next(nan_gen) for _ in range(num_steps)]

        is_nan = np.random.uniform(size=target.shape) < nan_prob
        target = np.where(is_nan, np.nan, target)

        target = noise + target
    else:
        # return the noise-free mean of the process instead
        target += mu
    return target


# define the parameters of the dataset
n_series = 1200
custom_ds_metadata = {
    'num_series': n_series,
    'num_steps': 24 * 7,
    'prediction_length': 24,
    'freq': '1H',
    'start': [pd.Timestamp("01-01-2019", freq='1H') for _ in range(n_series)],
}



sigma = 0.3
# note: prediction_length (= 24) is passed as the sinusoid period here
target = create_dataset(
    custom_ds_metadata['num_series'],
    custom_ds_metadata['num_steps'],
    custom_ds_metadata['prediction_length'],
    sigma=sigma,
)
train_ds = ListDataset(
    [
        {FieldName.TARGET: t, FieldName.START: start}
        for (t, start) in zip(
            target[:, :-custom_ds_metadata['prediction_length']],
            custom_ds_metadata['start'],
        )
    ],
    freq=custom_ds_metadata['freq'],
)
test_ds = ListDataset(
    [
        {FieldName.TARGET: t, FieldName.START: start}
        for (t, start) in zip(target, custom_ds_metadata['start'])
    ],
    freq=custom_ds_metadata['freq'],
)


train_entry = next(iter(train_ds))
test_entry = next(iter(test_ds))



test_series = to_pandas(test_entry)
train_series = to_pandas(train_entry)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(10, 7))

train_series.plot(ax=ax[0])
ax[0].grid(which="both")
ax[0].legend(["train series"], loc="upper left")

test_series.plot(ax=ax[1])
ax[1].axvline(train_series.index[-1], color='r') # end of train dataset
ax[1].grid(which="both")
ax[1].legend(["test series", "end of train series"], loc="upper left")

plt.show()



from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.mx.trainer import Trainer
from gluonts.mx.distribution import NanMixtureOutput, GaussianOutput




estimator = SimpleFeedForwardEstimator(
    num_hidden_dimensions=[24, 24],
    prediction_length=custom_ds_metadata['prediction_length'],
    context_length=2 * custom_ds_metadata['prediction_length'],
    freq=custom_ds_metadata['freq'],
    trainer=Trainer(
        ctx="cpu",
        epochs=30,
        learning_rate=1e-3,
        hybridize=True,
        num_batches_per_epoch=n_series // 16,
        batch_size=16,
    ),
    distr_output=NanMixtureOutput(GaussianOutput()),
    sampling=False,
)



predictor = estimator.train(train_ds)





from gluonts.evaluation.backtest import make_evaluation_predictions
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds,  # test dataset
    predictor=predictor,  # predictor
    num_samples=4000,  # number of sample paths we want for evaluation
)
forecasts = list(forecast_it)
tss = list(ts_it)




# get the true means (noise-free targets)
target_mean = create_dataset(
    custom_ds_metadata['num_series'],
    custom_ds_metadata['num_steps'],
    custom_ds_metadata['prediction_length'],
    true_mean=True,
)

test_ds_mean = ListDataset(
    [
        {FieldName.TARGET: t, FieldName.START: start}
        for (t, start) in zip(target_mean, custom_ds_metadata['start'])
    ],
    freq=custom_ds_metadata['freq'],
)
true_mean = next(iter(test_ds_mean))["target"]



# get the predicted distribution parameters
# (drill down through the forecast's distribution wrappers to the
# Gaussian parameters and the missing-value probability)
forecast_entry = forecasts[0]
mu_hat = forecast_entry.distribution.distribution.base_distribution.mu.asnumpy()
sigma_hat = forecast_entry.distribution.distribution.base_distribution.sigma.asnumpy()
nan_prob_hat = forecast_entry.distribution.nan_prob.asnumpy()




# plot the predicted distribution parameters against the true parameters

ts_entry = tss[0]
plot_length = 80
legend = ["observations", "median prediction"]

fig, ax = plt.subplots(1, 1, figsize=(10, 7))
ax.set_ylabel('target', color="g")
ax.set_xlabel('time')

ax.tick_params(axis='y', labelcolor="g")

# true mean and a +/- one-standard-deviation band
ax.plot(ts_entry.index[-plot_length:], true_mean[-plot_length:], color="g", label="true target mean")
ax.plot(ts_entry.index[-plot_length:], true_mean[-plot_length:] + sigma, color="g", linestyle=":")
ax.plot(ts_entry.index[-plot_length:], true_mean[-plot_length:] - sigma, color="g", linestyle=":")

# predicted mean and a +/- one-standard-deviation band
ax.plot(ts_entry.index[-custom_ds_metadata['prediction_length']:], mu_hat, color="red", linestyle="-",
        label="predicted target mean")
ax.plot(ts_entry.index[-custom_ds_metadata['prediction_length']:], mu_hat - sigma_hat, color="red", linestyle=":")
ax.plot(ts_entry.index[-custom_ds_metadata['prediction_length']:], mu_hat + sigma_hat, color="red", linestyle=":")

ax.legend()
ax2 = ax.twinx()

ax2.set_ylabel('probability of missing values', color="b")

start = pd.Timestamp("01-01-2019", freq='1H')
end = start + pd.to_timedelta(custom_ds_metadata["num_steps"] - 1, unit='h')

# reconstruct the true missing-value probability curve
nan_prob = []
nan_gen = nan_prob_gen()
for i in range(custom_ds_metadata["num_steps"]):
    nan_prob.append(next(nan_gen))

ax2.plot(pd.date_range(start, end, freq="h"), nan_prob, color="blue", linestyle="-",
         label="true missing value probability")
ax2.plot(ts_entry.index[-custom_ds_metadata['prediction_length']:], nan_prob_hat, color="darkred", linestyle="-",
         label="predicted missing value probability")

ax2.tick_params(axis='y', labelcolor="blue")
ax2.set_ylim([0, 1])
ax2.set_xlim([ts_entry.index[-plot_length], end])
fig.tight_layout()
plt.title("Predicting a NanMixture: sine pattern with Gaussian noise and changing missing-value probability")
plt.legend()

plt.show()

@lostella lostella merged commit e52864f into awslabs:master Jul 28, 2020
@PascalIversen PascalIversen deleted the distribution_missing_values branch July 28, 2020 15:46
kashif pushed a commit to kashif/gluon-ts that referenced this pull request Oct 10, 2020
* added a deterministic/degenerate distribution

* Corrected the formula of the Standard Deviation of the Mixture Distribution

* Added a distribution to model missing values

* added test script to test log_prob and the gradients and fixed some edge cases

* fixed bug in test script

* corrected true gradients in the test file and edge cases of the log_prob calculations

* fixed edge cases of the gradients

* bugfix CategoricalOutput

* test skip

* style fix

* addressed PR issues

* skipping tests which take too long and rearranging imports

* fixed the output args issue of a NanMixture with a Categorical distribution

* lowered assertion tolerances

* lowered assertion tolerances

* refactoring of the NanMixture tests

* refactoring of the NanMixture tests

* added NanMixture support to the SimpleFeedForward and fixed typing issue

* refactoring

* removing SimpleFeedForward changes

* increased sample size for mixture stddev and mean tests to prevent false alarms

* added random seeds to tests

* increased tol

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>