
Unify QuantileOutput and DistributionOutput #3093

Merged: 25 commits into awslabs:dev on Jan 10, 2024

Conversation

@shchur (Contributor) commented on Dec 27, 2023:

Issue #, if available: closes #3083

Description of changes:

  • Move loss computation into the Output object and remove all losses defined in gluonts.torch.modules.loss (a sketch of this unified contract follows the list).
  • Make the following torch models compatible with both DistributionOutput and QuantileOutput:
    • SimpleFeedForward
    • TemporalFusionTransformer
    • DLinear
    • PatchTST
    • LagTST
  • Change the return type of the forward method of the following models to Tuple[Tuple[Tensor, ...], Tensor, Tensor]:
    • MQCNN (MXNet)
    • MQRNN (MXNet)
    • TemporalFusionTransformer (MXNet)
    • TemporalFusionTransformer (PyTorch)
  • Update the logic inside QuantileForecastGenerator to support the new unified signature of the forward method
  • Replace the predict_to_numpy method with to_numpy

cc @kashif

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@shchur marked this pull request as draft on Dec 27, 2023, 13:19
@shchur added labels: BREAKING, enhancement (Dec 27, 2023)
@shchur changed the title from "[WIP] Unify QuantileOutput and DistributionOutput" to "Unify QuantileOutput and DistributionOutput" (Dec 27, 2023)
@shchur marked this pull request as ready for review on Dec 27, 2023, 15:25
@shchur requested a review from @lostella (Dec 27, 2023, 15:26)
@lostella added labels: models, mxnet, torch (Dec 28, 2023)
@lostella (Contributor) left a comment:

Thanks @shchur, left some comments but I'll take a deeper look.

Resolved review threads (outdated):
  • src/gluonts/model/forecast_generator.py
  • src/gluonts/torch/distributions/distribution_output.py
  • src/gluonts/torch/distributions/quantile.py
Comment on lines 118 to 122:

```python
(outputs,), loc, scale = prediction_net(*inputs.values())
if scale is not None:
    outputs = outputs * scale[..., None]
if loc is not None:
    outputs = outputs + loc[..., None]
```
@lostella (Contributor) commented:

I'm wondering: would it be better to apply to_numpy here? Otherwise the type of the output of prediction_net depends on the type of prediction_net, and we're using indexing/multiplication/addition without really knowing whether it will work. Of course, it's just indexing/multiplication/addition, but I'm wondering if it would be somehow clearer if these objects were np.ndarray already at this point.

@shchur (Author) replied:
I agree, your suggestion feels cleaner. Implemented it.

```diff
  yield QuantileForecast(
-     output,
+     output.T,
```
@lostella (Contributor) commented:

Why move the transposition here, from inside the model?

@shchur (Author) replied:

QuantileForecast expects an array of shape [num_quantiles, prediction_length], but models usually produce output of shape [batch_size, prediction_length, *additional_dims], both in the case of DistributionOutput and QuantileOutput. I think it's better to keep the model output shape consistent and do the transpose here.

```diff
@@ -77,7 +78,7 @@ def assert_shapes_and_dtypes(tensors, shapes, dtypes):
 TemporalFusionTransformerModel(
     context_length=24,
     prediction_length=12,
-    quantiles=[0.2, 0.25, 0.5, 0.9, 0.95],
+    distr_output=QuantileOutput([0.2, 0.25, 0.5, 0.9, 0.95]),
```
@lostella (Contributor) commented:
I guess this class now also supports other parametric distribution families, right? Should some test case be added for that?

@shchur (Author) replied.

@lostella (Contributor): Yes, sorry, somehow I missed that.

@lostella mentioned this pull request on Jan 9, 2024
```diff
@@ -292,7 +292,7 @@ def __init__(
     es_num_samples: int = 50,
     beta: float = 1.0,
 ) -> None:
-    super().__init__(self)
+    super().__init__(self, beta=beta)
```
@lostella (Contributor) commented:
Then I guess there's no need to set self.beta further below (line 306)

@shchur (Author) replied:
Good catch!

@shchur merged commit c99dafa into awslabs:dev on Jan 10, 2024
19 checks passed
@maxc01 added a commit to maxc01/gluonts that referenced this pull request on Jan 12, 2024
Labels
  • BREAKING: this is a breaking change (one of the required PR labels)
  • enhancement: new feature or request
  • models: this item concerns model implementations
  • mxnet: this concerns the MXNet side of GluonTS
  • torch: this concerns the PyTorch side of GluonTS
Development: successfully merging this pull request may close Unify QuantileOutput and DistributionOutput (#3083).

2 participants