Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasckng committed Jul 26, 2024
1 parent 6a35792 commit 1bcf32c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def __repr__(self):

def __init__(self, parameter_names: list[str], **kwargs):
super().__init__(parameter_names)
assert self.n_dim == 1, "StandardNormalDistribution needs to be 1D distributions"
assert (
self.n_dim == 1
), "StandardNormalDistribution needs to be 1D distributions"

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
Expand All @@ -131,7 +133,8 @@ def sample(

def log_prob(self, x: dict[str, Float]) -> Float:
variable = x[self.parameter_names[0]]
return -0.5 * variable ** 2 - 0.5 * jnp.log(2 * jnp.pi)
return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi)


class SequentialTransform(Prior):
"""
Expand Down Expand Up @@ -249,6 +252,7 @@ def __init__(
],
)


@jaxtyped(typechecker=typechecker)
class PeriodicUniform(SequentialTransform):
xmin: float
Expand Down
2 changes: 2 additions & 0 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
super().__init__(name_mapping)
self.transform_func = lambda x: 1 / (1 + jnp.exp(-x))


class Modulo(UnivariateTransform):
"""
Modulo transform following
Expand All @@ -158,6 +159,7 @@ def __init__(
self.modulo = modulo
self.transform_func = lambda x: jnp.mod(x, self.modulo)


class ArcSine(UnivariateTransform):
"""
ArcSine transformation
Expand Down

0 comments on commit 1bcf32c

Please sign in to comment.