From 0e654312836dd32ea557d547e6e8c7d6128de002 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Tue, 22 Aug 2023 10:51:26 +0200 Subject: [PATCH] add default values to lower/upper bounds --- src/gluonts/torch/distributions/truncated_normal.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index fc09d128c0..44d8b58b9d 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -86,8 +86,8 @@ def __init__( self, loc: torch.Tensor, scale: torch.Tensor, - min: Union[torch.Tensor, float], - max: Union[torch.Tensor, float], + min: Union[torch.Tensor, float] = -1.0, + max: Union[torch.Tensor, float] = 1.0, upscale: Union[torch.Tensor, float] = 5.0, tanh_loc: bool = False, ): @@ -240,8 +240,8 @@ class TruncatedNormalOutput(DistributionOutput): @validated() def __init__( self, - min: float, - max: float, + min: float = -1.0, + max: float = 1.0, upscale: float = 5.0, tanh_loc: bool = False, ) -> None: