diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e2d98f71..7dd50018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -105,7 +105,9 @@ def __repr__(self): def __init__(self, parameter_names: list[str], **kwargs): super().__init__(parameter_names) - assert self.n_dim == 1, "StandardNormalDistribution needs to be 1D distributions" + assert ( + self.n_dim == 1 + ), "StandardNormalDistribution needs to be 1D distributions" def sample( self, rng_key: PRNGKeyArray, n_samples: int @@ -131,7 +133,8 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: variable = x[self.parameter_names[0]] - return -0.5 * variable ** 2 - 0.5 * jnp.log(2 * jnp.pi) + return -0.5 * variable**2 - 0.5 * jnp.log(2 * jnp.pi) + class SequentialTransform(Prior): """ @@ -249,6 +252,7 @@ def __init__( ], ) + @jaxtyped(typechecker=typechecker) class PeriodicUniform(SequentialTransform): xmin: float diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 05dacdb1..8a4787c5 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -138,6 +138,7 @@ def __init__( super().__init__(name_mapping) self.transform_func = lambda x: 1 / (1 + jnp.exp(-x)) + class Modulo(UnivariateTransform): """ Modulo transform following @@ -158,6 +159,7 @@ def __init__( self.modulo = modulo self.transform_func = lambda x: jnp.mod(x, self.modulo) + class ArcSine(UnivariateTransform): """ ArcSine transformation