Skip to content

Commit

Permalink
fix the sign method of affine transformation (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Feb 17, 2020
1 parent 887ce26 commit 48a1b10
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/gluonts/distribution/bijection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# permissions and limitations under the License.

# Standard library imports
from typing import Optional
from typing import Optional, Union

# Third-party imports
import numpy as np
Expand Down Expand Up @@ -74,7 +74,7 @@ def event_dim(self) -> int:
raise NotImplementedError()

@property
def sign(self) -> Tensor:
def sign(self) -> Union[float, Tensor]:
"""
Return the sign of the Jacobian's determinant.
"""
Expand All @@ -96,6 +96,7 @@ class InverseBijection(Bijection):

@validated()
def __init__(self, bijection: Bijection) -> None:
super().__init__(self)
self._bijection = bijection

def f(self, x: Tensor) -> Tensor:
Expand All @@ -115,7 +116,7 @@ def event_dim(self) -> int:
return self._bijection.event_dim

@property
def sign(self) -> Tensor:
def sign(self) -> Union[float, Tensor]:
return self._bijection.sign


Expand All @@ -134,7 +135,7 @@ def event_dim(self) -> int:
return 0

@property
def sign(self) -> Tensor:
def sign(self) -> float:
return 1.0


Expand All @@ -153,7 +154,7 @@ def event_dim(self) -> int:
return 0

@property
def sign(self) -> Tensor:
def sign(self) -> float:
return 1.0


Expand Down Expand Up @@ -183,7 +184,7 @@ def event_dim(self) -> int:
return 0

@property
def sign(self) -> Tensor:
def sign(self) -> float:
return 1.0


Expand All @@ -207,6 +208,7 @@ class AffineTransformation(Bijection):
def __init__(
self, loc: Optional[Tensor] = None, scale: Optional[Tensor] = None
) -> None:
super().__init__(self)
self.loc = loc
self.scale = scale

Expand All @@ -233,8 +235,8 @@ def log_abs_det_jac(self, x: Tensor, y: Tensor) -> Tensor:
return F.zeros_like(x)

@property
def sign(self):
return self.scale.sign()
def sign(self) -> Union[float, Tensor]:
return 1.0 if self.scale is None else self.scale.sign()

@property
def event_dim(self) -> int:
Expand Down

0 comments on commit 48a1b10

Please sign in to comment.