Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nonnegative predictions for deepar mxnet #2957

Merged
merged 14 commits into from
Aug 11, 2023

Conversation

melopeo
Copy link
Contributor

@melopeo melopeo commented Aug 9, 2023

Issue #, if available:

Description of changes:
Add functionality to generate nonnegative prediction samples for DeepAR. This is applied only to final samples for prediction.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@melopeo melopeo added the enhancement New feature or request label Aug 9, 2023
melopeo and others added 3 commits August 10, 2023 13:51
Improve readability on testing

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>
Hardcode nonnegative parameter in tests

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>
"""

if self.nonnegative_pred_samples:
return F.Activation(samples, act_type="relu")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this works both in symbolic vs eager mode

Suggested change
return F.Activation(samples, act_type="relu")
return F.relu(samples)


dataset_train, dataset_test = datasets
predictor = estimator.train(dataset_train)
forecasts = list(predictor.predict(dataset_test))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To ensure the added feature also works in symbolic mode, you can call as_symbol_block_predictor on the predictor object, see

def as_symbol_block_predictor(

@melopeo melopeo merged commit 589281e into awslabs:dev Aug 11, 2023
18 of 21 checks passed
@melopeo melopeo deleted the nonnegative_predictions_for_deepar-mxnet branch August 11, 2023 11:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants