Skip to content

Commit

Permalink
Merge pull request #35 from kazewong/small_corrections
Browse files Browse the repository at this point in the history
Correcting prior and adding pv2
  • Loading branch information
kazewong authored Aug 19, 2023
2 parents f9de92f + dc3b4ae commit f67f281
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,5 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
return samples # TODO: remember to cast this to a named array

def log_prob(self, x: Array) -> Float:
output = jax.lax.cond(not jnp.where((x>=self.xmax) | (x<=self.xmin))[0].any(), lambda: 0., lambda: -jnp.inf)
output = jnp.sum(jnp.where((x>=self.xmax) | (x<=self.xmin), jnp.zeros_like(x)-jnp.inf, jnp.zeros_like(x)))
return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin)))
23 changes: 20 additions & 3 deletions src/jimgw/waveform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jaxtyping import Array
from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_polar
from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc
from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc
import jax.numpy as jnp
from abc import ABC

Expand All @@ -22,10 +23,26 @@ def __call__(self, frequency: Array, params: dict) -> dict:
output = {}
ra = params['ra']
dec = params['dec']
theta = [params['M_c'], params['eta'], params['s1_z'], params['s2_z'], params['d_L'], 0, params['phase_c'], params['iota'], params['psi'], ra, dec]
hp, hc = gen_IMRPhenomD_polar(frequency, theta, self.f_ref)
theta = [params['M_c'], params['eta'], params['s1_z'], params['s2_z'], params['d_L'], 0, params['phase_c'], params['iota']]
hp, hc = gen_IMRPhenomD_hphc(frequency, theta, self.f_ref)
output['p'] = hp
output['c'] = hc
return output

class RippleIMRPhenomPv2(Waveform):

f_ref: float

def __init__(self, f_ref: float = 20.0):
self.f_ref = f_ref

def __call__(self, frequency: Array, params: dict) -> Array:
output = {}
theta = [params['M_c'], params['eta'], 0.0, 0.0, params['s1_z'],
0.0, 0.0, params['s2_z'],
params['d_L'], 0, params['phase_c'], params['iota']]
hp, hc = gen_IMRPhenomPv2_hphc(frequency, theta, self.f_ref)
output['p'] = hp
output['c'] = hc
return output

0 comments on commit f67f281

Please sign in to comment.