
Fix: removed input re-normalization by rin inside io_processor #2160

Merged: 2 commits into unit8co:master on Jan 13, 2024

Conversation

FourierMourier (Contributor)

Fixes #2159.

Summary

Currently, the io_processor wrapper does not clone the input tensor before applying .rin to it, which overwrites the original input tensor in place and leads to incorrect results during the .predict method. Applying .clone() to args[0][0] inside io_processor prevents this behavior.
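
The aliasing this fixes can also be shown in isolation. Here is a minimal sketch (not part of the PR, with a simple stand-in normalization in place of self.rin) of how an in-place slice assignment on a non-cloned tensor also mutates the caller's tensor, while a .clone() keeps the original intact:

import torch

batch = torch.tensor([[[0.1], [0.2], [0.3]]])      # the caller keeps a reference to this tensor

x = batch                                           # old behavior: no clone, x aliases batch
x[:, :, :1] = (x[:, :, :1] - x.mean()) / x.std()    # in-place stand-in for self.rin(...)
print(batch.view(-1))                               # the caller's tensor has been normalized too

batch = torch.tensor([[[0.1], [0.2], [0.3]]])
x = batch.clone()                                   # fixed behavior: work on a copy
x[:, :, :1] = (x[:, :, :1] - x.mean()) / x.std()
print(batch.view(-1))                               # the caller's tensor keeps its original values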

Other Information

To reproduce the previous and the current behavior, you can use the following snippet:

from functools import wraps
import torch

from darts.models import NBEATSModel
from darts.models.forecasting.nbeats import _NBEATSModule
from darts.timeseries import TimeSeries

from typing import Tuple


class Counter:
    def __init__(self):
        self.value = 0

    def incr(self):
        self.value += 1

    def reset(self):
        self.value = 0


LEAK_COUNTER = Counter()
SAFE_COUNTER = Counter()


def io_processor(forward):
    @wraps(forward)
    def forward_wrapper(self, *args, **kwargs):
        if not self.use_reversible_instance_norm:
            return forward(self, *args, **kwargs)

        # args[0] is the input batch tuple, which by definition has the past features in its first
        # element, starting with the first n target features
        # clone args[0][0] so the in-place normalization below does not overwrite the caller's tensor
        x: torch.Tensor = args[0][0].clone()
        # apply reversible instance normalization
        x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets])
        # run the forward pass
        out = forward(self, *((x, *args[0][1:]), *args[1:]), **kwargs)
        # inverse transform target output back to original scale; by definition the first output
        if isinstance(out, tuple):
            return self.rin.inverse(out[0]), *out[1:]
        else:
            return self.rin.inverse(out)

    return forward_wrapper


def io_processor_with_leak(forward):
    @wraps(forward)
    def forward_wrapper(self, *args, **kwargs):
        if not self.use_reversible_instance_norm:
            return forward(self, *args, **kwargs)

        # args[0] is the input batch tuple, which by definition has the past features in its first
        # element, starting with the first n target features
        # no clone here: this intentionally reproduces the old in-place behavior that leaks back to the caller
        x: torch.Tensor = args[0][0]
        # apply reversible instance normalization
        x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets])
        # run the forward pass
        out = forward(self, *((x, *args[0][1:]), *args[1:]), **kwargs)
        # inverse transform target output back to original scale; by definition the first output
        if isinstance(out, tuple):
            return self.rin.inverse(out[0]), *out[1:]
        else:
            return self.rin.inverse(out)

    return forward_wrapper


def run_nbeats_forward(self, x_in):
    x, _ = x_in

    # if x1, x2,... y1, y2... is one multivariate ts containing x and y, and a1, a2... one covariate ts
    # we reshape into x1, y1, a1, x2, y2, a2... etc
    x = torch.reshape(x, (x.shape[0], self.input_chunk_length_multi, 1))
    # squeeze last dimension (because model is univariate)
    x = x.squeeze(dim=2)

    # One vector of length target_length per parameter in the distribution
    y = torch.zeros(
        x.shape[0],
        self.target_length,
        self.nr_params,
        device=x.device,
        dtype=x.dtype,
    )

    for stack in self.stacks_list:
        # compute stack output
        stack_residual, stack_forecast = stack(x)

        # add stack forecast to final output
        y = y + stack_forecast

        # set current stack residual as input for next stack
        x = stack_residual

    # In multivariate case, we get a result [x1_param1, x1_param2], [y1_param1, y1_param2], [x2..], [y2..], ...
    # We want to reshape to original format. We also get rid of the covariates and keep only the target dimensions.
    # The covariates are by construction added as extra time series on the right side. So we need to get rid of this
    # right output (keeping only :self.output_dim).
    y = y.view(
        y.shape[0], self.output_chunk_length, self.input_dim, self.nr_params
    )[:, :, : self.output_dim, :]

    return y


def run_produce_predict_output(self, x, counter: Counter):
    if self.likelihood:
        output = self(x)
        if self.predict_likelihood_parameters:
            return self.likelihood.predict_likelihood_parameters(output)
        else:
            return self.likelihood.sample(output)
    else:
        step = counter.value
        if step == 0:  # prevent overlap with predicting pbar
            print(f"\nproducing predictions of {self.__class__.__name__}...")
        print(f"step {step}: before forward = {x[0].view(-1)}")
        out = self(x).squeeze(dim=-1)
        print(f"step {step}: after forward = {x[0].view(-1)}")
        counter.incr()
        return out


class _NBEATSModuleNoLeak(_NBEATSModule):
    @io_processor
    def forward(self, x_in: Tuple):
        return run_nbeats_forward(self, x_in)

    def _produce_predict_output(self, x: Tuple) -> torch.Tensor:
        return run_produce_predict_output(self, x, SAFE_COUNTER)


class _NBEATSModuleWithLeak(_NBEATSModule):
    @io_processor_with_leak
    def forward(self, x_in: Tuple):
        return run_nbeats_forward(self, x_in)

    def _produce_predict_output(self, x: Tuple) -> torch.Tensor:
        return run_produce_predict_output(self, x, LEAK_COUNTER)


class NBEATSModelInitial(NBEATSModel):
    def __init__(self, input_chunk_length: int, output_chunk_length: int, **kwargs):
        super().__init__(input_chunk_length, output_chunk_length, **kwargs)

    def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
        # samples are made of (past_target, past_covariates, future_target)
        input_dim = train_sample[0].shape[1] + (
            train_sample[1].shape[1] if train_sample[1] is not None else 0
        )
        output_dim = train_sample[-1].shape[1]
        nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

        return _NBEATSModuleWithLeak(
            input_dim=input_dim,
            output_dim=output_dim,
            nr_params=nr_params,
            generic_architecture=self.generic_architecture,
            num_stacks=self.num_stacks,
            num_blocks=self.num_blocks,
            num_layers=self.num_layers,
            layer_widths=self.layer_widths,
            expansion_coefficient_dim=self.expansion_coefficient_dim,
            trend_polynomial_degree=self.trend_polynomial_degree,
            batch_norm=self.batch_norm,
            dropout=self.dropout,
            activation=self.activation,
            **self.pl_module_params,
        )


class NBEATSModelNoLeak(NBEATSModel):
    def __init__(self, input_chunk_length: int, output_chunk_length: int, **kwargs):
        super().__init__(input_chunk_length, output_chunk_length, **kwargs)

    def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
        # samples are made of (past_target, past_covariates, future_target)
        input_dim = train_sample[0].shape[1] + (
            train_sample[1].shape[1] if train_sample[1] is not None else 0
        )
        output_dim = train_sample[-1].shape[1]
        nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

        return _NBEATSModuleNoLeak(
            input_dim=input_dim,
            output_dim=output_dim,
            nr_params=nr_params,
            generic_architecture=self.generic_architecture,
            num_stacks=self.num_stacks,
            num_blocks=self.num_blocks,
            num_layers=self.num_layers,
            layer_widths=self.layer_widths,
            expansion_coefficient_dim=self.expansion_coefficient_dim,
            trend_polynomial_degree=self.trend_polynomial_degree,
            batch_norm=self.batch_norm,
            dropout=self.dropout,
            activation=self.activation,
            **self.pl_module_params,
        )


def run_model(model: NBEATSModel):
    print('*' * 40)
    print(f"running {model.__class__.__name__}")
    sample_data_values = torch.linspace(0.1, 1.9, 19)
    # sample_data_indices = torch.arange(0, 9, 1)
    darts_series = TimeSeries.from_values(sample_data_values.cpu().numpy())
    x_in = (sample_data_values, None)
    model.fit(series=[darts_series])
    input_slice = sample_data_values[:model.input_chunk_length]
    input_slice_copy = input_slice.clone()
    input_slice = input_slice.unsqueeze(0).unsqueeze(-1)
    # call forward()
    _ = model.model((input_slice, None))
    print((f'\ninput before forward: {input_slice_copy}\n'
           f'after forward: {input_slice.view(*input_slice_copy.shape)}\n'))
    # call predict:
    input_darts_series = TimeSeries.from_values(input_slice_copy.cpu().numpy())
    LEAK_COUNTER.reset()
    SAFE_COUNTER.reset()
    pred = model.predict(series=input_darts_series, n=2)
    pred_data = pred.values().reshape(-1)
    print(pred_data)
    return


def main():
    m_kwargs = dict(input_chunk_length=3, output_chunk_length=1, use_reversible_instance_norm=True, n_epochs=10)
    initial_one = NBEATSModelInitial(**m_kwargs)
    no_leak_model = NBEATSModelNoLeak(**m_kwargs)
    run_model(initial_one)
    print("\n\n")
    run_model(no_leak_model)


if __name__ == '__main__':
    main()

which will produce the output below:

****************************************
running NBEATSModelInitial
...
Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  3.60it/s, train_loss=0.000421]
`Trainer.fit` stopped: `max_epochs=10` reached.

input before forward: tensor([0.1000, 0.2000, 0.3000])
after forward: tensor([-1.2301, -0.0036,  1.2228], grad_fn=<ViewBackward0>)
...
Predicting DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]
producing predictions of _NBEATSModuleWithLeak...
step 0: before forward = tensor([0.1000, 0.2000, 0.3000])
step 0: after forward = tensor([-1.2301, -0.0036,  1.2228])
step 1: before forward = tensor([-0.0036,  1.2228,  0.3983])
step 1: after forward = tensor([-1.0692,  1.3384, -0.2801])
Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 27.82it/s]
[0.39834473 1.602938  ]



****************************************
running NBEATSModelNoLeak
...
Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s, train_loss=0.000684]
`Trainer.fit` stopped: `max_epochs=10` reached.

input before forward: tensor([0.1000, 0.2000, 0.3000])
after forward: tensor([0.1000, 0.2000, 0.3000])
...
Predicting DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]
producing predictions of _NBEATSModuleNoLeak...
step 0: before forward = tensor([0.1000, 0.2000, 0.3000])
step 0: after forward = tensor([0.1000, 0.2000, 0.3000])
step 1: before forward = tensor([0.2000, 0.3000, 0.4237])
step 1: after forward = tensor([0.2000, 0.3000, 0.4237])
Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 23.18it/s]
[0.42366236 0.5575354 ]

@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison: base (ea79679) 93.92% vs. head (7a33d75) 93.91%.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2160      +/-   ##
==========================================
- Coverage   93.92%   93.91%   -0.01%     
==========================================
  Files         135      135              
  Lines       13335    13321      -14     
==========================================
- Hits        12525    12511      -14     
  Misses        810      810              


@dennisbader (Collaborator) left a comment


This is great, thanks a lot @FourierMourier. I observed this unexpected behavior as well recently when predicting with n > output_chunk_length.

Do we need to clone during training as well?
Just want to make sure we avoid any unnecessary operations.

@FourierMourier (Contributor, Author)

@dennisbader Hi. Hmm, I think that if the dataset is reconstructed every epoch it might not be necessary (assuming each batch is unique and does not overlap with the others, or at least cannot overlap), but I'd rather keep it to make sure the behavior during training stays the same as during inference.
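
For context on that concern, here is a hypothetical sketch (not how darts necessarily builds its training batches) of the failure mode the clone guards against: if samples were overlapping views into one shared array, an in-place normalization inside forward() would leak across samples.

import torch

series = torch.arange(6, dtype=torch.float32)
windows = [series[i:i + 3] for i in range(3)]   # overlapping views, no copies

windows[0] -= windows[0].mean()                 # in-place op on the first window
print(series)                                   # the shared series changed, so the later
                                                # (overlapping) windows now see altered values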

@dennisbader (Collaborator) left a comment


I checked again and the overhead is negligible for training.

Thanks again for this @FourierMourier 🚀
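
One way to sanity-check that overhead is a quick micro-benchmark of .clone() on a representative batch; a rough sketch with an assumed batch shape (not from the PR):

import time
import torch

x = torch.randn(32, 512, 8)   # assumed (batch, input_chunk_length, n_features) shape

start = time.perf_counter()
for _ in range(1000):
    _ = x.clone()
elapsed = time.perf_counter() - start
print(f"average clone time: {elapsed / 1000 * 1e6:.1f} microseconds")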

@dennisbader merged commit cb724d1 into unit8co:master on Jan 13, 2024
8 of 9 checks passed
@FourierMourier (Contributor, Author)

Glad it works fine


Successfully merging this pull request may close the following issue:

[BUG] Input gets re-normalized several times during .predict method