Skip to content

Commit

Permalink
Fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Aug 23, 2022
2 parents 154f0d0 + f1cca9a commit 2e5be70
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
27 changes: 25 additions & 2 deletions gpjax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,31 @@

__config = None

Identity = dx.Lambda(lambda x: x)
Softplus = dx.Lambda(lambda x: jnp.log(1.0 + jnp.exp(x)))
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
Softplus = dx.Lambda(
forward=lambda x: jnp.log(1 + jnp.exp(x)),
inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),
)

# class Softplus(dx.Bijector):
# def __init__(self):
# super().__init__(event_ndims_in=0)

# def forward_and_log_det(self, x):
# softplus = lambda xx: jnp.log(1 + jnp.exp(xx))
# y = softplus(x)
# logdet = softplus(-x)
# return y, logdet

# def inverse_and_log_det(self, y):
# """
# Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1]
# ==> dX/dY = exp{Y} / (exp{Y} - 1)
# = 1 / (1 - exp{-Y})
# """
# x = jnp.log(jnp.exp(y) - 1.0)
# logdet = 1 / (1 - jnp.exp(-y))
# return x, logdet


def get_defaults() -> ConfigDict:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def parse_requirements_file(filename):
"optax",
"chex",
"distrax>=0.1.2",
"tensorflow-probability==0.16.0",
"tensorflow-probability>=0.16.0",
"tqdm>=4.0.0",
"ml-collections==0.1.0",
"jaxtyping",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,5 @@ def test_output(num_datapoints, likelihood):
a_constrainers, a_unconstrainers = build_transforms(augmented_params)
assert "test_param" in list(a_constrainers.keys())
assert "test_param" in list(a_unconstrainers.keys())
assert a_constrainers["test_param"](1.0) == 1.0
assert a_unconstrainers["test_param"](1.0) == 1.0
assert a_constrainers["test_param"](jnp.array([1.0])) == 1.0
assert a_unconstrainers["test_param"](jnp.array([1.0])) == 1.0

0 comments on commit 2e5be70

Please sign in to comment.