Skip to content

Commit

Permalink
Add soft clamping to affine transform, otherwise the loss explodes ev…
Browse files Browse the repository at this point in the history
…en for simple cases
  • Loading branch information
Radev committed May 14, 2024
1 parent c8005fb commit faae8df
Showing 1 changed file with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@

import numpy as np
from math import pi as PI_CONST

from keras import ops

from bayesflow.experimental.types import Tensor
from .transform import Transform


class AffineTransform(Transform):

def __init__(self, clamp_factor=1.9):
self.clamp_factor = clamp_factor

def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]:
scale, shift = ops.split(parameters, 2, axis=-1)

return {"scale": scale, "shift": shift}

def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tensor]:
shift = np.log(np.e - 1)
parameters["scale"] = ops.softplus(parameters["scale"] + shift)
s = (2.0 * self.clamp_factor / PI_CONST) * ops.atan(parameters["scale"] / self.soft_clamp)
parameters["scale"] = ops.exp(s)

return parameters

def forward(self, x: Tensor, parameters: dict[str, Tensor]) -> (Tensor, Tensor):
z = parameters["scale"] * x + parameters["shift"]
log_det = ops.mean(ops.log(parameters["scale"]), axis=-1)
log_det = ops.mean(parameters["scale"], axis=-1)

return z, log_det

def inverse(self, z: Tensor, parameters: dict[str, Tensor]) -> (Tensor, Tensor):
x = (z - parameters["shift"]) / parameters["scale"]
log_det = -ops.mean(ops.log(parameters["scale"]), axis=-1)
log_det = -ops.mean(parameters["scale"], axis=-1)

return x, log_det

0 comments on commit faae8df

Please sign in to comment.