Skip to content

Commit

Permalink
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
Browse files Browse the repository at this point in the history
…ss' into transform
  • Loading branch information
thomasckng committed Aug 1, 2024
2 parents 8368d00 + 47af9cf commit b4f6052
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 142 deletions.
24 changes: 15 additions & 9 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jaxtyping import Array, Float, PRNGKeyArray

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior
from jimgw.prior import Prior, trace_prior_parent
from jimgw.transforms import BijectiveTransform, NtoMTransform


Expand All @@ -22,6 +22,7 @@ class Jim(object):
# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler

def __init__(
Expand All @@ -37,16 +38,18 @@ def __init__(

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
print("No sample transforms provided. Using prior parameters as sampling parameters")
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)
print("No likelihood transforms provided. Using prior parameters as likelihood parameters")


seed = kwargs.get("seed", 0)

Expand Down Expand Up @@ -94,12 +97,15 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):
prior = self.prior.log_prob(named_params) + transform_jacobian
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
return self.likelihood.evaluate(named_params, data) + prior
named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate
return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.prior.sample(key, self.Sampler.n_chains)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
for transform in self.sample_transforms:
initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate
self.Sampler.sample(initial_guess, None) # type: ignore

def maximize_likelihood(
Expand Down
147 changes: 14 additions & 133 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Callable

import jax
import jax.numpy as jnp
from chex import assert_rank
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped

from jimgw.single_event.utils import (
Mc_q_to_m1_m2,
m1_m2_to_Mc_q,
q_to_eta,
eta_to_q,
ra_dec_to_zenith_azimuth,
zenith_azimuth_to_ra_dec,
euler_rotation,
)


class Transform(ABC):
"""
Expand Down Expand Up @@ -270,6 +261,7 @@ def __init__(

@jaxtyped(typechecker=typechecker)
class BoundToBound(BijectiveTransform):

"""
Bound to bound transformation
"""
Expand Down Expand Up @@ -308,7 +300,7 @@ def __init__(
for i in range(len(name_mapping[1]))
}


@jaxtyped(typechecker=typechecker)
class BoundToUnbound(BijectiveTransform):
"""
Bound to unbound transformation
Expand All @@ -323,13 +315,13 @@ def __init__(
original_lower_bound: Float,
original_upper_bound: Float,
):

def logit(x):
return jnp.log(x / (1 - x))

super().__init__(name_mapping)
self.original_lower_bound = original_lower_bound
self.original_upper_bound = original_upper_bound
self.original_lower_bound = jnp.atleast_1d(original_lower_bound)
self.original_upper_bound = jnp.atleast_1d(original_upper_bound)

self.transform_func = lambda x: {
name_mapping[1][i]: logit(
Expand All @@ -339,13 +331,17 @@ def logit(x):
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound)
/ (1 + jnp.exp(-x[name_mapping[1][i]]))
name_mapping[0][i]: (
self.original_upper_bound - self.original_lower_bound
)
/ (
1
+ jnp.exp(-x[name_mapping[1][i]])
)
+ self.original_lower_bound[i]
for i in range(len(name_mapping[1]))
}


class SingleSidedUnboundTransform(BijectiveTransform):
"""
Unbound upper limit transformation
Expand All @@ -372,121 +368,6 @@ def __init__(
}


class ChirpMassMassRatioToComponentMassesTransform(BijectiveTransform):
"""
Transform chirp mass and mass ratio to component masses
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)

def named_transform(x):
Mc = x[name_mapping[0][0]]
q = x[name_mapping[0][1]]
m1, m2 = Mc_q_to_m1_m2(Mc, q)
return {name_mapping[1][0]: m1, name_mapping[1][1]: m2}

self.transform_func = named_transform

def named_inverse_transform(x):
m1 = x[name_mapping[1][0]]
m2 = x[name_mapping[1][1]]
Mc, q = m1_m2_to_Mc_q(m1, m2)
return {name_mapping[0][0]: Mc, name_mapping[0][1]: q}

self.inverse_transform_func = named_inverse_transform


class ChirpMassMassRatioToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform chirp mass and mass ratio to chirp mass and symmetric mass ratio
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)

def named_transform(x):
Mc = x[name_mapping[0][0]]
q = x[name_mapping[0][1]]
eta = q_to_eta(q)
return {name_mapping[1][0]: Mc, name_mapping[1][1]: eta}

self.transform_func = named_transform

def named_inverse_transform(x):
Mc = x[name_mapping[1][0]]
eta = x[name_mapping[1][1]]
q = eta_to_q(Mc, eta)
return {name_mapping[0][0]: Mc, name_mapping[0][1]: q}

self.inverse_transform_func = named_inverse_transform


class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
"""
Transform sky frame to detector frame sky position
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

gmst: Float
rotation: Float[Array, " 3 3"]
rotation_inv: Float[Array, " 3 3"]

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
gmst: Float,
delta_x: Float,
):
super().__init__(name_mapping)

self.gmst = gmst
self.rotation = euler_rotation(delta_x)
self.rotation_inv = jnp.linalg.inv(self.rotation)

def named_transform(x):
ra = x[name_mapping[0][0]]
dec = x[name_mapping[0][1]]
zenith, azimuth = ra_dec_to_zenith_azimuth(
ra, dec, self.gmst, self.rotation
)
return {name_mapping[1][0]: zenith, name_mapping[1][1]: azimuth}

self.transform_func = named_transform

def named_inverse_transform(x):
zenith = x[name_mapping[1][0]]
azimuth = x[name_mapping[1][1]]
ra, dec = zenith_azimuth_to_ra_dec(
zenith, azimuth, self.gmst, self.rotation_inv
)
return {name_mapping[0][0]: ra, name_mapping[0][1]: dec}

self.inverse_transform_func = named_inverse_transform


# class PowerLawTransform(UnivariateTransform):
# """
Expand Down
17 changes: 17 additions & 0 deletions test/integration/test_GW150914.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -64,6 +65,21 @@
dec_prior,
]
)

sample_transforms = [
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0),
BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0),
BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05),
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi)
]

likelihood = TransientLikelihoodFD(
[H1, L1],
waveform=RippleIMRPhenomD(),
Expand All @@ -88,6 +104,7 @@
jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
Expand Down

0 comments on commit b4f6052

Please sign in to comment.