From e5c6796263fa901c37f86208c2f40fd61dcb415f Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Wed, 8 May 2024 02:51:20 -0700 Subject: [PATCH 1/8] added TF2 waveform --- example/GW150914_TaylorF2.py | 155 ++++++++++++++++++++++ example/GW170817_TaylorF2.py | 201 +++++++++++++++++++++++++++++ src/jimgw/single_event/waveform.py | 89 ++++++++++++- 3 files changed, 444 insertions(+), 1 deletion(-) create mode 100644 example/GW150914_TaylorF2.py create mode 100644 example/GW170817_TaylorF2.py diff --git a/example/GW150914_TaylorF2.py b/example/GW150914_TaylorF2.py new file mode 100644 index 00000000..2520b794 --- /dev/null +++ b/example/GW150914_TaylorF2.py @@ -0,0 +1,155 @@ +import psutil +p = psutil.Process() +p.cpu_affinity([0]) +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "2" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + +import time + +import jax +import jax.numpy as jnp +import optax + +from jimgw.jim import Jim +from jimgw.prior import Composite, Unconstrained_Uniform +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleTaylorF2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Unconstrained_Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) +s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) +lambda1_prior = Unconstrained_Uniform(0.0, 5000.0, naming=["lambda_1"]) +lambda2_prior = Unconstrained_Uniform(0.0, 5000.0, naming=["lambda_2"]) +dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + lambda1_prior, + lambda2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] +) +likelihood = TransientLikelihoodFD( + [H1, L1], + waveform=RippleTaylorF2(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + +n_dim = 13 +mass_matrix = jnp.eye(n_dim) +mass_matrix 
= mass_matrix.at[0,0].set(1e-5) +mass_matrix = mass_matrix.at[1,1].set(1e-4) +mass_matrix = mass_matrix.at[2,2].set(1e-3) +mass_matrix = mass_matrix.at[3,3].set(1e-3) +mass_matrix = mass_matrix.at[7,7].set(1e-5) +mass_matrix = mass_matrix.at[11,11].set(1e-2) +mass_matrix = mass_matrix.at[12,12].set(1e-2) +local_sampler_arg = {"step_size": mass_matrix * 1e-3} + +# Build the learning rate scheduler + +n_loop_training = 100 +n_epochs = 100 +total_epochs = n_epochs * n_loop_training +start = int(total_epochs / 10) +start_lr = 1e-3 +end_lr = 1e-5 +power = 4.0 +schedule_fn = optax.polynomial_schedule( + start_lr, end_lr, power, total_epochs-start, transition_begin=start) + +jim = Jim( + likelihood, + prior, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=1000, + n_epochs=n_epochs, + learning_rate=schedule_fn, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=20, + output_thinning=50, + local_sampler_arg=local_sampler_arg, +) + +jim.sample(jax.random.PRNGKey(24)) +jim.print_summary() diff --git a/example/GW170817_TaylorF2.py b/example/GW170817_TaylorF2.py new file mode 100644 index 00000000..7bd62d4a --- /dev/null +++ b/example/GW170817_TaylorF2.py @@ -0,0 +1,201 @@ +import psutil +p = psutil.Process() +p.cpu_affinity([0]) +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" +from jimgw.jim import Jim +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleTaylorF2 +from jimgw.prior import Uniform, Composite +import jax.numpy as jnp +import jax +import time +jax.config.update("jax_enable_x64", True) +import numpy as np +import optax +from gwosc.datasets import event_gps +print(f"GPU found? 
{jax.devices()}") + + +data_path = "/home/thibeau.wouters/gw-datasets/GW170817/" # on CIT + +start_runtime = time.time() + +############ +### BODY ### +############ + +### Data definitions + +total_time_start = time.time() +gps = 1187008882.43 +trigger_time = gps +fmin = 20 +fmax = 2048 +minimum_frequency = fmin +maximum_frequency = fmax +duration = 128 +# epoch = duration - post_trigger_duration +post_trigger_duration = 32 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +f_ref = fmin +tukey_alpha = 2 / (duration / 2) + +ifos = ["H1", "L1"]#, "V1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=tukey_alpha, gwpy_kwargs={"version": 2, "cache": False}) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=tukey_alpha, gwpy_kwargs={"version": 2, "cache": False}) +# V1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.05) + +### Define priors + +# Internal parameters +Mc_prior = Uniform(1.18, 1.21, naming=["M_c"]) +q_prior = Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1z_prior = Uniform(-0.05, 0.05, naming=["s1_z"]) +s2z_prior = Uniform(-0.05, 0.05, naming=["s2_z"]) +lambda_1_prior = Uniform(0.0, 5000.0, naming=["lambda_1"]) +lambda_2_prior = Uniform(0.0, 5000.0, naming=["lambda_2"]) +dL_prior = Uniform(1.0, 75.0, naming=["d_L"]) +# dL_prior = PowerLaw(1.0, 75.0, 2.0, naming=["d_L"]) +t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"]) +phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior_list = [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + lambda_1_prior, + lambda_2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] + +prior = Composite(prior_list) + +# The following only works if every prior has xmin and xmax property, which is OK for Uniform and Powerlaw +bounds = jnp.array([[p.xmin, p.xmax] for p in prior.priors]) + +### Create likelihood object + +ref_params = { + 'M_c': 1.19793583, + 'eta': 0.24794374, + 's1_z': 0.00220637, + 's2_z': 0.0499, + 'lambda_1': 605.12916663, + 'lambda_2': 405.12916663, + 'd_L': 45.41592353, + 't_c': 0.00220588, + 'phase_c': 5.76822606, + 'iota': 2.46158044, + 'psi': 2.09118099, + 'ra': 5.03335133, + 'dec': 0.01679998 +} + +n_bins = 100 + +likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=RippleTaylorF2(f_ref=f_ref), trigger_time=gps, duration=duration, n_bins=n_bins, ref_params=ref_params) +print("Running with n_bins = ", n_bins) + +# Local sampler args + +eps = 1e-3 +n_dim = 13 +mass_matrix = jnp.eye(n_dim) +mass_matrix = mass_matrix.at[0,0].set(1e-5) +mass_matrix = mass_matrix.at[1,1].set(1e-4) +mass_matrix = mass_matrix.at[2,2].set(1e-3) +mass_matrix = mass_matrix.at[3,3].set(1e-3) +mass_matrix = mass_matrix.at[7,7].set(1e-5) +mass_matrix = mass_matrix.at[11,11].set(1e-2) +mass_matrix = 
mass_matrix.at[12,12].set(1e-2) +local_sampler_arg = {"step_size": mass_matrix * eps} + +# Build the learning rate scheduler + +n_loop_training = 200 +n_epochs = 50 +total_epochs = n_epochs * n_loop_training +start = int(total_epochs / 10) +start_lr = 1e-3 +end_lr = 1e-5 +power = 4.0 +schedule_fn = optax.polynomial_schedule( + start_lr, end_lr, power, total_epochs-start, transition_begin=start) + +scheduler_str = f"polynomial_schedule({start_lr}, {end_lr}, {power}, {total_epochs-start}, {start})" + +# Create jim object + +outdir_name = "./outdir/" + +jim = Jim( + likelihood, + prior, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=500, + n_chains=1000, + n_epochs=n_epochs, + learning_rate=schedule_fn, + max_samples=50000, + momentum=0.9, + batch_size=50000, + use_global=True, + keep_quantile=0.0, + train_thinning=10, + output_thinning=30, + local_sampler_arg=local_sampler_arg, + stopping_criterion_global_acc = 0.20, + outdir_name=outdir_name +) + +### Sample and show results + +jim.sample(jax.random.PRNGKey(41)) +jim.print_summary() \ No newline at end of file diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index 2434e836..6be70ef4 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -4,7 +4,8 @@ from jaxtyping import Array, Float from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc - +from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc +from ripple.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc class Waveform(ABC): def __init__(self): @@ -81,8 +82,94 @@ def __call__( def __repr__(self): return f"RippleIMRPhenomPv2(f_ref={self.f_ref})" +class RippleTaylorF2(Waveform): + + f_ref: float + use_lambda_tildes: bool + + def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): + self.f_ref = f_ref + self.use_lambda_tildes = use_lambda_tildes + + def __call__(self, frequency: Array, params: dict) -> dict: + output = {} + ra = params["ra"] + dec = params["dec"] + + if self.use_lambda_tildes: + first_lambda_param = params["lambda_tilde"] + second_lambda_param = params["delta_lambda_tilde"] + else: + first_lambda_param = params["lambda_1"] + second_lambda_param = params["lambda_2"] + + theta = [ + params["M_c"], + params["eta"], + params["s1_z"], + params["s2_z"], + first_lambda_param, + second_lambda_param, + params["d_L"], + 0, + params["phase_c"], + params["iota"], + ] + hp, hc = gen_TaylorF2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes) + output["p"] = hp + output["c"] = hc + return output + + def __repr__(self): + return f"RippleTaylorF2(f_ref={self.f_ref})" + +class RippleIMRPhenomD_NRTidalv2(Waveform): + + f_ref: float + use_lambda_tildes: bool + + def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): + self.f_ref = f_ref + self.use_lambda_tildes = use_lambda_tildes + + def __call__(self, frequency: Array, params: dict) -> dict: + output = {} + ra = params["ra"] + dec = params["dec"] + + if self.use_lambda_tildes: + first_lambda_param = params["lambda_tilde"] + second_lambda_param = params["delta_lambda_tilde"] + else: + first_lambda_param = params["lambda_1"] + second_lambda_param = params["lambda_2"] + + theta = [ + params["M_c"], + params["eta"], + params["s1_z"], + params["s2_z"], + first_lambda_param, + second_lambda_param, + params["d_L"], + 0, + params["phase_c"], + params["iota"], + ] + + hp, hc = 
gen_IMRPhenomD_NRTidalv2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes) + output["p"] = hp + output["c"] = hc + return output + + def __repr__(self): + return f"RippleIMRPhenomD_NRTidalv2(f_ref={self.f_ref})" + + waveform_preset = { "RippleIMRPhenomD": RippleIMRPhenomD, "RippleIMRPhenomPv2": RippleIMRPhenomPv2, + "RippleTaylorF2": RippleTaylorF2, + "RippleIMRPhenomD_NRTidalv2": RippleIMRPhenomD_NRTidalv2, } From a66911f54230c98a79a91f92ef4e4162c08deec5 Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 23 May 2024 01:21:33 -0700 Subject: [PATCH 2/8] added PhenomD_NRTv2 and example scripts; reproducing previous results not done yet --- .gitignore | 3 + example/.gitignore | 2 + example/GW170817_PhenomD_NRTv2.py | 298 +++++++++++++++++++++++++++ example/GW170817_TaylorF2.py | 153 +++++++++++--- src/jimgw/single_event/likelihood.py | 11 +- src/jimgw/single_event/waveform.py | 15 +- 6 files changed, 445 insertions(+), 37 deletions(-) create mode 100644 example/.gitignore create mode 100644 example/GW170817_PhenomD_NRTv2.py diff --git a/.gitignore b/.gitignore index 5d6606d3..499ad55f 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,6 @@ H1.txt L1.txt V1.txt test_data + +# Out directory of runs +outdir diff --git a/example/.gitignore b/example/.gitignore new file mode 100644 index 00000000..d3666e0b --- /dev/null +++ b/example/.gitignore @@ -0,0 +1,2 @@ +utils_plotting.py +outdir*/ diff --git a/example/GW170817_PhenomD_NRTv2.py b/example/GW170817_PhenomD_NRTv2.py new file mode 100644 index 00000000..6d1d6adb --- /dev/null +++ b/example/GW170817_PhenomD_NRTv2.py @@ -0,0 +1,298 @@ +import psutil +p = psutil.Process() +p.cpu_affinity([0]) +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" +from jimgw.jim import Jim +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD_NRTidalv2 +from jimgw.prior import Uniform, Composite +import jax.numpy as jnp +import jax +import time +import numpy as np +jax.config.update("jax_enable_x64", True) +import shutil +import numpy as np +import matplotlib.pyplot as plt +import optax +print(f"Devices found by Jax: {jax.devices()}") + +import utils_plotting as utils + +################ +### PREAMBLE ### +################ + +data_path = "/home/thibeau.wouters/gw-datasets/GW170817/" # this is on the CIT cluster # TODO: has to be shared externally! + +start_runtime = time.time() + +############ +### BODY ### +############ + +### Data definitions + +gps = 1187008882.43 +trigger_time = gps +fmin = 20 +fmax = 2048 +minimum_frequency = fmin +maximum_frequency = fmax +duration = 128 +post_trigger_duration = 2 +epoch = duration - post_trigger_duration +f_ref = fmin + +### Getting detector data + +# This is our preprocessed data obtained from the TXT files at the GWOSC website (the GWF gave me NaNs?) 
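+# Each detector is populated by hand here: a frequency grid, the complex frequency-domain
+# strain (real and imaginary parts stored in separate text files), and a PSD evaluated on
+# the same frequency grid further below.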
+H1.frequencies = jnp.array(np.genfromtxt(f'{data_path}H1_freq.txt')) +H1_data_re, H1_data_im = jnp.array(np.genfromtxt(f'{data_path}H1_data_re.txt')), jnp.array(np.genfromtxt(f'{data_path}H1_data_im.txt')) +H1.data = H1_data_re + 1j * H1_data_im + +L1.frequencies = jnp.array(np.genfromtxt(f'{data_path}L1_freq.txt')) +L1_data_re, L1_data_im = jnp.array(np.genfromtxt(f'{data_path}L1_data_re.txt')), jnp.array(np.genfromtxt(f'{data_path}L1_data_im.txt')) +L1.data = L1_data_re + 1j * L1_data_im + +V1.frequencies = jnp.array(np.genfromtxt(f'{data_path}V1_freq.txt')) +V1_data_re, V1_data_im = jnp.array(np.genfromtxt(f'{data_path}V1_data_re.txt')), jnp.array(np.genfromtxt(f'{data_path}V1_data_im.txt')) +V1.data = V1_data_re + 1j * V1_data_im + +# Load the PSD + +H1.psd = H1.load_psd(H1.frequencies, psd_file = data_path + "GW170817-IMRD_data0_1187008882-43_generation_data_dump.pickle_H1_psd.txt") +L1.psd = L1.load_psd(L1.frequencies, psd_file = data_path + "GW170817-IMRD_data0_1187008882-43_generation_data_dump.pickle_L1_psd.txt") +V1.psd = V1.load_psd(V1.frequencies, psd_file = data_path + "GW170817-IMRD_data0_1187008882-43_generation_data_dump.pickle_V1_psd.txt") + +### Define priors + +# Internal parameters +Mc_prior = Uniform(1.18, 1.21, naming=["M_c"]) +q_prior = Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1z_prior = Uniform(-0.05, 0.05, naming=["s1_z"]) +s2z_prior = Uniform(-0.05, 0.05, naming=["s2_z"]) +lambda_1_prior = Uniform(0.0, 5000.0, naming=["lambda_1"]) +lambda_2_prior = Uniform(0.0, 5000.0, naming=["lambda_2"]) +dL_prior = Uniform(1.0, 75.0, naming=["d_L"]) +t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"]) +phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior_list = [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + lambda_1_prior, + lambda_2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] + +prior = Composite(prior_list) + +# The following only works if every prior has xmin and xmax property, which is OK for Uniform and Powerlaw +bounds = jnp.array([[p.xmin, p.xmax] for p in prior.priors]) + +### Create likelihood object + +# For simplicity, we put here a set of reference parameters found by the optimizer +ref_params = { + 'M_c': 1.1975896, + 'eta': 0.2461001, + 's1_z': -0.01890608, + 's2_z': 0.04888488, + 'lambda_1': 791.04366468, + 'lambda_2': 891.04366468, + 'd_L': 16.06331818, + 't_c': 0.00193536, + 'phase_c': 5.88649652, + 'iota': 1.93095421, + 'psi': 1.59687217, + 'ra': 3.39736826, + 'dec': -0.34000186 +} + +# Number of bins to use for relative binning +n_bins = 500 + +waveform = RippleIMRPhenomD_NRTidalv2(f_ref=f_ref) +reference_waveform = RippleIMRPhenomD_NRTidalv2(f_ref=f_ref, no_taper=True) + +likelihood = HeterodynedTransientLikelihoodFD([H1, L1, V1], + prior=prior, + bounds=bounds, + waveform=waveform, + trigger_time=gps, + duration=duration, + n_bins=n_bins, + 
ref_params=ref_params, + reference_waveform=reference_waveform) + +# Local sampler args + +eps = 1e-3 +n_dim = 13 +mass_matrix = jnp.eye(n_dim) +mass_matrix = mass_matrix.at[0,0].set(1e-5) +mass_matrix = mass_matrix.at[1,1].set(1e-4) +mass_matrix = mass_matrix.at[2,2].set(1e-3) +mass_matrix = mass_matrix.at[3,3].set(1e-3) +mass_matrix = mass_matrix.at[7,7].set(1e-5) +mass_matrix = mass_matrix.at[11,11].set(1e-2) +mass_matrix = mass_matrix.at[12,12].set(1e-2) +local_sampler_arg = {"step_size": mass_matrix * eps} + +# Build the learning rate scheduler (if used) + +n_loop_training = 300 +n_epochs = 50 +total_epochs = n_epochs * n_loop_training +start = int(total_epochs / 10) +start_lr = 1e-3 +end_lr = 1e-5 +power = 4.0 +schedule_fn = optax.polynomial_schedule( + start_lr, end_lr, power, total_epochs-start, transition_begin=start) + +scheduler_str = f"polynomial_schedule({start_lr}, {end_lr}, {power}, {total_epochs-start}, {start})" + +## Choose between fixed learning rate - or - the above scheduler here +# learning_rate = schedule_fn +learning_rate = 0.001 + +print(f"Learning rate: {learning_rate}") + +# Create jim object + +outdir_name = "./outdir/" + +jim = Jim( + likelihood, + prior, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=100, + n_global_steps=1000, + n_chains=1000, + n_epochs=n_epochs, + learning_rate=schedule_fn, + max_samples=50000, + momentum=0.9, + batch_size=50000, + use_global=True, + keep_quantile=0.0, + train_thinning=10, + output_thinning=30, + local_sampler_arg=local_sampler_arg, + outdir_name=outdir_name +) + +### Heavy computation begins +jim.sample(jax.random.PRNGKey(41)) +### Heavy computation ends + +# === Show results, save output === + +# Print a summary to screen: +jim.print_summary() +outdir = outdir_name + +# Save and plot the results of the run +# - training phase + +name = outdir + f'results_training.npz' +print(f"Saving samples to {name}") +state = jim.Sampler.get_sampler_state(training=True) +chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[ + "log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"] +local_accs = jnp.mean(local_accs, axis=0) +global_accs = jnp.mean(global_accs, axis=0) +np.savez(name, log_prob=log_prob, local_accs=local_accs, + global_accs=global_accs, loss_vals=loss_vals) + +utils.plot_accs(local_accs, "Local accs (training)", + "local_accs_training", outdir) +utils.plot_accs(global_accs, "Global accs (training)", + "global_accs_training", outdir) +utils.plot_loss_vals(loss_vals, "Loss", "loss_vals", outdir) +utils.plot_log_prob(log_prob, "Log probability (training)", + "log_prob_training", outdir) + +# - production phase +name = outdir + f'results_production.npz' +state = jim.Sampler.get_sampler_state(training=False) +chains, log_prob, local_accs, global_accs = state["chains"], state[ + "log_prob"], state["local_accs"], state["global_accs"] +local_accs = jnp.mean(local_accs, axis=0) +global_accs = jnp.mean(global_accs, axis=0) +np.savez(name, chains=chains, log_prob=log_prob, + local_accs=local_accs, global_accs=global_accs) + +utils.plot_accs(local_accs, "Local accs (production)", + "local_accs_production", outdir) +utils.plot_accs(global_accs, "Global accs (production)", + "global_accs_production", outdir) +utils.plot_log_prob(log_prob, "Log probability (production)", + "log_prob_production", outdir) + +# Plot the chains as corner plots +utils.plot_chains(chains, "chains_production", outdir, truths=None) + +# Save the NF and show a plot of samples from the flow 
+print("Saving the NF") +jim.Sampler.save_flow(outdir + "nf_model") + +# Final steps + + +print("Finished successfully") + +end_runtime = time.time() +runtime = end_runtime - start_runtime +print(f"Time taken: {runtime} seconds ({(runtime)/60} minutes)") + +print(f"Saving runtime") +with open(outdir + 'runtime.txt', 'w') as file: + file.write(str(runtime)) \ No newline at end of file diff --git a/example/GW170817_TaylorF2.py b/example/GW170817_TaylorF2.py index 7bd62d4a..554e0679 100644 --- a/example/GW170817_TaylorF2.py +++ b/example/GW170817_TaylorF2.py @@ -12,14 +12,21 @@ import jax.numpy as jnp import jax import time +import numpy as np jax.config.update("jax_enable_x64", True) +import shutil import numpy as np +import matplotlib.pyplot as plt import optax -from gwosc.datasets import event_gps -print(f"GPU found? {jax.devices()}") +print(f"Devices found by Jax: {jax.devices()}") + +import utils_plotting as utils +################ +### PREAMBLE ### +################ -data_path = "/home/thibeau.wouters/gw-datasets/GW170817/" # on CIT +data_path = "/home/thibeau.wouters/gw-datasets/GW170817/" # this is on the CIT cluster # TODO: has to be shared externally! start_runtime = time.time() @@ -29,7 +36,6 @@ ### Data definitions -total_time_start = time.time() gps = 1187008882.43 trigger_time = gps fmin = 20 @@ -37,18 +43,30 @@ minimum_frequency = fmin maximum_frequency = fmax duration = 128 -# epoch = duration - post_trigger_duration -post_trigger_duration = 32 -start_pad = duration - post_trigger_duration -end_pad = post_trigger_duration +post_trigger_duration = 2 +epoch = duration - post_trigger_duration f_ref = fmin -tukey_alpha = 2 / (duration / 2) -ifos = ["H1", "L1"]#, "V1"] +### Getting detector data + +# This is our preprocessed data obtained from the TXT files at the GWOSC website (the GWF gave me NaNs?) 
+H1.frequencies = jnp.array(np.genfromtxt(f'{data_path}H1_freq.txt')) +H1_data_re, H1_data_im = jnp.array(np.genfromtxt(f'{data_path}H1_data_re.txt')), jnp.array(np.genfromtxt(f'{data_path}H1_data_im.txt')) +H1.data = H1_data_re + 1j * H1_data_im + +L1.frequencies = jnp.array(np.genfromtxt(f'{data_path}L1_freq.txt')) +L1_data_re, L1_data_im = jnp.array(np.genfromtxt(f'{data_path}L1_data_re.txt')), jnp.array(np.genfromtxt(f'{data_path}L1_data_im.txt')) +L1.data = L1_data_re + 1j * L1_data_im + +V1.frequencies = jnp.array(np.genfromtxt(f'{data_path}V1_freq.txt')) +V1_data_re, V1_data_im = jnp.array(np.genfromtxt(f'{data_path}V1_data_re.txt')), jnp.array(np.genfromtxt(f'{data_path}V1_data_im.txt')) +V1.data = V1_data_re + 1j * V1_data_im + +# Load the PSD -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=tukey_alpha, gwpy_kwargs={"version": 2, "cache": False}) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=tukey_alpha, gwpy_kwargs={"version": 2, "cache": False}) -# V1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.05) +H1.psd = H1.load_psd(H1.frequencies, psd_file = data_path + "GW170817-IMRD_data0_1187008882-43_generation_data_dump.pickle_H1_psd.txt") +L1.psd = L1.load_psd(L1.frequencies, psd_file = data_path + "GW170817-IMRD_data0_1187008882-43_generation_data_dump.pickle_L1_psd.txt") +V1.psd = V1.load_psd(V1.frequencies, psd_file = data_path + "GW170817-IMRD_data0_1187008882-43_generation_data_dump.pickle_V1_psd.txt") ### Define priors @@ -65,7 +83,6 @@ lambda_1_prior = Uniform(0.0, 5000.0, naming=["lambda_1"]) lambda_2_prior = Uniform(0.0, 5000.0, naming=["lambda_2"]) dL_prior = Uniform(1.0, 75.0, naming=["d_L"]) -# dL_prior = PowerLaw(1.0, 75.0, 2.0, naming=["d_L"]) t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"]) phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) cos_iota_prior = Uniform( @@ -120,13 +137,14 @@ ### Create likelihood object +# For simplicity, we put here a set of reference parameters found by the optimizer ref_params = { 'M_c': 1.19793583, 'eta': 0.24794374, 's1_z': 0.00220637, - 's2_z': 0.0499, - 'lambda_1': 605.12916663, - 'lambda_2': 405.12916663, + 's2_z': 0.05, + 'lambda_1': 105.12916663, + 'lambda_2': 0.0, 'd_L': 45.41592353, 't_c': 0.00220588, 'phase_c': 5.76822606, @@ -136,10 +154,18 @@ 'dec': 0.01679998 } -n_bins = 100 +# Number of bins to use for relative binning +n_bins = 500 -likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=RippleTaylorF2(f_ref=f_ref), trigger_time=gps, duration=duration, n_bins=n_bins, ref_params=ref_params) -print("Running with n_bins = ", n_bins) +waveform = RippleTaylorF2(f_ref=f_ref) +likelihood = HeterodynedTransientLikelihoodFD([H1, L1, V1], + prior=prior, + bounds=bounds, + waveform=waveform, + trigger_time=gps, + duration=duration, + n_bins=n_bins, + ref_params=ref_params) # Local sampler args @@ -155,9 +181,9 @@ mass_matrix = mass_matrix.at[12,12].set(1e-2) local_sampler_arg = {"step_size": mass_matrix * eps} -# Build the learning rate scheduler +# Build the learning rate scheduler (if used) -n_loop_training = 200 +n_loop_training = 300 n_epochs = 50 total_epochs = n_epochs * n_loop_training start = int(total_epochs / 10) @@ -169,17 +195,22 @@ scheduler_str = f"polynomial_schedule({start_lr}, {end_lr}, {power}, {total_epochs-start}, {start})" -# Create jim object +## Choose between fixed learning rate - or - the above scheduler here +# learning_rate = schedule_fn +learning_rate = 0.001 + 
+print(f"Learning rate: {learning_rate}") -outdir_name = "./outdir/" +# Create jim object +outdir_name = "./outdir_TF2/" jim = Jim( likelihood, prior, n_loop_training=n_loop_training, n_loop_production=20, - n_local_steps=10, - n_global_steps=500, + n_local_steps=100, + n_global_steps=1000, n_chains=1000, n_epochs=n_epochs, learning_rate=schedule_fn, @@ -191,11 +222,73 @@ train_thinning=10, output_thinning=30, local_sampler_arg=local_sampler_arg, - stopping_criterion_global_acc = 0.20, outdir_name=outdir_name ) -### Sample and show results +### Heavy computation begins +jim.sample(jax.random.PRNGKey(43)) +### Heavy computation ends + +# === Show results, save output === + +# Print a summary to screen: +jim.print_summary() +outdir = outdir_name + +# Save and plot the results of the run +# - training phase + +name = outdir + f'results_training.npz' +print(f"Saving samples to {name}") +state = jim.Sampler.get_sampler_state(training=True) +chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[ + "log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"] +local_accs = jnp.mean(local_accs, axis=0) +global_accs = jnp.mean(global_accs, axis=0) +np.savez(name, log_prob=log_prob, local_accs=local_accs, + global_accs=global_accs, loss_vals=loss_vals) + +utils.plot_accs(local_accs, "Local accs (training)", + "local_accs_training", outdir) +utils.plot_accs(global_accs, "Global accs (training)", + "global_accs_training", outdir) +utils.plot_loss_vals(loss_vals, "Loss", "loss_vals", outdir) +utils.plot_log_prob(log_prob, "Log probability (training)", + "log_prob_training", outdir) + +# - production phase +name = outdir + f'results_production.npz' +state = jim.Sampler.get_sampler_state(training=False) +chains, log_prob, local_accs, global_accs = state["chains"], state[ + "log_prob"], state["local_accs"], state["global_accs"] +local_accs = jnp.mean(local_accs, axis=0) +global_accs = jnp.mean(global_accs, axis=0) +np.savez(name, chains=chains, log_prob=log_prob, + local_accs=local_accs, global_accs=global_accs) + +utils.plot_accs(local_accs, "Local accs (production)", + "local_accs_production", outdir) +utils.plot_accs(global_accs, "Global accs (production)", + "global_accs_production", outdir) +utils.plot_log_prob(log_prob, "Log probability (production)", + "log_prob_production", outdir) + +# Plot the chains as corner plots +utils.plot_chains(chains, "chains_production", outdir, truths=None) + +# Save the NF and show a plot of samples from the flow +print("Saving the NF") +jim.Sampler.save_flow(outdir + "nf_model") + +# Final steps + + +print("Finished successfully") + +end_runtime = time.time() +runtime = end_runtime - start_runtime +print(f"Time taken: {runtime} seconds ({(runtime)/60} minutes)") -jim.sample(jax.random.PRNGKey(41)) -jim.print_summary() \ No newline at end of file +print(f"Saving runtime") +with open(outdir + 'runtime.txt', 'w') as file: + file.write(str(runtime)) \ No newline at end of file diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 75d052ba..d6347f9e 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -192,6 +192,7 @@ def __init__( popsize: int = 100, n_steps: int = 2000, ref_params: dict = {}, + reference_waveform: Waveform = None, **kwargs, ) -> None: super().__init__( @@ -199,6 +200,10 @@ def __init__( ) print("Initializing heterodyned likelihood..") + + # Can use another waveform to use as reference waveform, but if not provided, use the same 
waveform + if reference_waveform is None: + reference_waveform = waveform self.kwargs = kwargs if "marginalization" in self.kwargs: @@ -281,7 +286,7 @@ def __init__( self.B0_array = {} self.B1_array = {} - h_sky = self.waveform(frequency_original, self.ref_params) + h_sky = reference_waveform(frequency_original, self.ref_params) # Get frequency masks to be applied, for both original # and heterodyne frequency grid @@ -307,8 +312,8 @@ def __init__( if len(self.freq_grid_low) > len(self.freq_grid_center): self.freq_grid_low = self.freq_grid_low[: len(self.freq_grid_center)] - h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) - h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) + h_sky_low = reference_waveform(self.freq_grid_low, self.ref_params) + h_sky_center = reference_waveform(self.freq_grid_center, self.ref_params) # Get phase shifts to align time of coalescence align_time = jnp.exp( diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index 6be70ef4..2f49fd8c 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -128,9 +128,18 @@ class RippleIMRPhenomD_NRTidalv2(Waveform): f_ref: float use_lambda_tildes: bool - def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): + def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False, no_taper: bool = False): + """ + Initialize the waveform. + + Args: + f_ref (float, optional): Reference frequency in Hz. Defaults to 20.0. + use_lambda_tildes (bool, optional): Whether we sample over lambda_tilde and delta_lambda_tilde, as defined for instance in Equation (5) and Equation (6) of arXiv:1402.5156, rather than lambda_1 and lambda_2. Defaults to False. + no_taper (bool, optional): Whether to remove the Planck taper in the amplitude of the waveform, which we use for relative binning runs. Defaults to False. 
+ """ self.f_ref = f_ref self.use_lambda_tildes = use_lambda_tildes + self.no_taper = no_taper def __call__(self, frequency: Array, params: dict) -> dict: output = {} @@ -157,7 +166,7 @@ def __call__(self, frequency: Array, params: dict) -> dict: params["iota"], ] - hp, hc = gen_IMRPhenomD_NRTidalv2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes) + hp, hc = gen_IMRPhenomD_NRTidalv2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes, no_taper=self.no_taper) output["p"] = hp output["c"] = hc return output @@ -165,8 +174,6 @@ def __call__(self, frequency: Array, params: dict) -> dict: def __repr__(self): return f"RippleIMRPhenomD_NRTidalv2(f_ref={self.f_ref})" - - waveform_preset = { "RippleIMRPhenomD": RippleIMRPhenomD, "RippleIMRPhenomPv2": RippleIMRPhenomPv2, From 5c9986994ec0d9067c879aadfd2eed0f680f7488 Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 23 May 2024 01:45:05 -0700 Subject: [PATCH 3/8] removing unused ra and dec --- src/jimgw/single_event/waveform.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index 2f49fd8c..29ba3a00 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -93,8 +93,6 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): def __call__(self, frequency: Array, params: dict) -> dict: output = {} - ra = params["ra"] - dec = params["dec"] if self.use_lambda_tildes: first_lambda_param = params["lambda_tilde"] @@ -143,8 +141,6 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False, no_tape def __call__(self, frequency: Array, params: dict) -> dict: output = {} - ra = params["ra"] - dec = params["dec"] if self.use_lambda_tildes: first_lambda_param = params["lambda_tilde"] From 25dff03e1a4c97ca828fabea2b4ddd1ecd6ad0b9 Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 23 May 2024 02:23:54 -0700 Subject: [PATCH 4/8] fixing precommit complaints --- example/GW170817_TaylorF2.py | 2 +- src/jimgw/single_event/likelihood.py | 3 ++- src/jimgw/single_event/waveform.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/example/GW170817_TaylorF2.py b/example/GW170817_TaylorF2.py index 554e0679..499f2d6d 100644 --- a/example/GW170817_TaylorF2.py +++ b/example/GW170817_TaylorF2.py @@ -2,7 +2,7 @@ p = psutil.Process() p.cpu_affinity([0]) import os -os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["CUDA_VISIBLE_DEVICES"] = "2" os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" from jimgw.jim import Jim from jimgw.single_event.detector import H1, L1, V1 diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index d6347f9e..295e2694 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -6,6 +6,7 @@ from flowMC.strategy.optimization import optimization_Adam from jax.scipy.special import logsumexp from jaxtyping import Array, Float +from typing import Optional from scipy.interpolate import interp1d from jimgw.base import LikelihoodBase @@ -192,7 +193,7 @@ def __init__( popsize: int = 100, n_steps: int = 2000, ref_params: dict = {}, - reference_waveform: Waveform = None, + reference_waveform: Optional[Waveform] = None, **kwargs, ) -> None: super().__init__( diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index 29ba3a00..aa3ae6e9 100644 --- a/src/jimgw/single_event/waveform.py +++ 
b/src/jimgw/single_event/waveform.py @@ -95,11 +95,11 @@ def __call__(self, frequency: Array, params: dict) -> dict: output = {} if self.use_lambda_tildes: - first_lambda_param = params["lambda_tilde"] - second_lambda_param = params["delta_lambda_tilde"] + first_lambda_param = jnp.array(params["lambda_tilde"]) + second_lambda_param = jnp.array(params["delta_lambda_tilde"]) else: - first_lambda_param = params["lambda_1"] - second_lambda_param = params["lambda_2"] + first_lambda_param = jnp.array(params["lambda_1"]) + second_lambda_param = jnp.array(params["lambda_2"]) theta = [ params["M_c"], @@ -143,11 +143,11 @@ def __call__(self, frequency: Array, params: dict) -> dict: output = {} if self.use_lambda_tildes: - first_lambda_param = params["lambda_tilde"] - second_lambda_param = params["delta_lambda_tilde"] + first_lambda_param = jnp.array(params["lambda_tilde"]) + second_lambda_param = jnp.array(params["delta_lambda_tilde"]) else: - first_lambda_param = params["lambda_1"] - second_lambda_param = params["lambda_2"] + first_lambda_param = jnp.array(params["lambda_1"]) + second_lambda_param = jnp.array(params["lambda_2"]) theta = [ params["M_c"], From 8aef32ca54e3a096527860835eac14791e064269 Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 23 May 2024 02:42:40 -0700 Subject: [PATCH 5/8] more precommit fixes --- src/jimgw/single_event/likelihood.py | 2 +- src/jimgw/single_event/waveform.py | 109 ++++++++++++++++----------- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 295e2694..58ffb083 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -201,7 +201,7 @@ def __init__( ) print("Initializing heterodyned likelihood..") - + # Can use another waveform to use as reference waveform, but if not provided, use the same waveform if reference_waveform is None: reference_waveform = waveform diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index aa3ae6e9..b5e5351c 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -7,6 +7,7 @@ from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc from ripple.waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc + class Waveform(ABC): def __init__(self): return NotImplemented @@ -82,6 +83,7 @@ def __call__( def __repr__(self): return f"RippleIMRPhenomPv2(f_ref={self.f_ref})" + class RippleTaylorF2(Waveform): f_ref: float @@ -93,40 +95,50 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): def __call__(self, frequency: Array, params: dict) -> dict: output = {} - + if self.use_lambda_tildes: - first_lambda_param = jnp.array(params["lambda_tilde"]) - second_lambda_param = jnp.array(params["delta_lambda_tilde"]) + first_lambda_param = params["lambda_tilde"] + second_lambda_param = params["delta_lambda_tilde"] else: - first_lambda_param = jnp.array(params["lambda_1"]) - second_lambda_param = jnp.array(params["lambda_2"]) - - theta = [ - params["M_c"], - params["eta"], - params["s1_z"], - params["s2_z"], - first_lambda_param, - second_lambda_param, - params["d_L"], - 0, - params["phase_c"], - params["iota"], - ] - hp, hc = gen_TaylorF2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes) + first_lambda_param = params["lambda_1"] + second_lambda_param = params["lambda_2"] + + theta = jnp.array( + [ + params["M_c"], + params["eta"], + params["s1_z"], + params["s2_z"], + first_lambda_param, + 
second_lambda_param, + params["d_L"], + 0, + params["phase_c"], + params["iota"], + ] + ) + hp, hc = gen_TaylorF2_hphc( + frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes + ) output["p"] = hp output["c"] = hc return output - + def __repr__(self): return f"RippleTaylorF2(f_ref={self.f_ref})" - + + class RippleIMRPhenomD_NRTidalv2(Waveform): f_ref: float use_lambda_tildes: bool - def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False, no_taper: bool = False): + def __init__( + self, + f_ref: float = 20.0, + use_lambda_tildes: bool = False, + no_taper: bool = False, + ): """ Initialize the waveform. @@ -141,35 +153,44 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False, no_tape def __call__(self, frequency: Array, params: dict) -> dict: output = {} - + if self.use_lambda_tildes: - first_lambda_param = jnp.array(params["lambda_tilde"]) - second_lambda_param = jnp.array(params["delta_lambda_tilde"]) + first_lambda_param = params["lambda_tilde"] + second_lambda_param = params["delta_lambda_tilde"] else: - first_lambda_param = jnp.array(params["lambda_1"]) - second_lambda_param = jnp.array(params["lambda_2"]) - - theta = [ - params["M_c"], - params["eta"], - params["s1_z"], - params["s2_z"], - first_lambda_param, - second_lambda_param, - params["d_L"], - 0, - params["phase_c"], - params["iota"], - ] - - hp, hc = gen_IMRPhenomD_NRTidalv2_hphc(frequency, theta, self.f_ref, use_lambda_tildes=self.use_lambda_tildes, no_taper=self.no_taper) + first_lambda_param = params["lambda_1"] + second_lambda_param = params["lambda_2"] + + theta = jnp.array( + [ + params["M_c"], + params["eta"], + params["s1_z"], + params["s2_z"], + first_lambda_param, + second_lambda_param, + params["d_L"], + 0, + params["phase_c"], + params["iota"], + ] + ) + + hp, hc = gen_IMRPhenomD_NRTidalv2_hphc( + frequency, + theta, + self.f_ref, + use_lambda_tildes=self.use_lambda_tildes, + no_taper=self.no_taper, + ) output["p"] = hp output["c"] = hc return output - + def __repr__(self): return f"RippleIMRPhenomD_NRTidalv2(f_ref={self.f_ref})" - + + waveform_preset = { "RippleIMRPhenomD": RippleIMRPhenomD, "RippleIMRPhenomPv2": RippleIMRPhenomPv2, From c77b8a69f3c242cac2b75e94a982b1acf35c04ea Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 26 May 2024 09:26:41 -0400 Subject: [PATCH 6/8] Update waveform.py Update typing information --- src/jimgw/single_event/waveform.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index b5e5351c..adc2f5b7 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -93,7 +93,8 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): self.f_ref = f_ref self.use_lambda_tildes = use_lambda_tildes - def __call__(self, frequency: Array, params: dict) -> dict: + def __call__(self, frequency: Float[Array, " n_dim"], params: dict[str, Float] + ) -> dict[str, Float[Array, " n_dim"]] output = {} if self.use_lambda_tildes: @@ -151,7 +152,8 @@ def __init__( self.use_lambda_tildes = use_lambda_tildes self.no_taper = no_taper - def __call__(self, frequency: Array, params: dict) -> dict: + def __call__(self, frequency: Float[Array, " n_dim"], params: dict[str, Float] + ) -> dict[str, Float[Array, " n_dim"]] output = {} if self.use_lambda_tildes: From 28264fb8def259e04a7230b5bb835bea5958160b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 26 May 2024 15:08:42 -0400 Subject: [PATCH 7/8] Update waveform.py 
minor bug fix --- src/jimgw/single_event/waveform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index adc2f5b7..b763e153 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -94,7 +94,7 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): self.use_lambda_tildes = use_lambda_tildes def __call__(self, frequency: Float[Array, " n_dim"], params: dict[str, Float] - ) -> dict[str, Float[Array, " n_dim"]] + ) -> dict[str, Float[Array, " n_dim"]]: output = {} if self.use_lambda_tildes: @@ -153,7 +153,7 @@ def __init__( self.no_taper = no_taper def __call__(self, frequency: Float[Array, " n_dim"], params: dict[str, Float] - ) -> dict[str, Float[Array, " n_dim"]] + ) -> dict[str, Float[Array, " n_dim"]]: output = {} if self.use_lambda_tildes: From 002843eb90acfdf7eba15607109d2ad18ac288fb Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 26 May 2024 15:15:07 -0400 Subject: [PATCH 8/8] Update waveform.py --- src/jimgw/single_event/waveform.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/waveform.py b/src/jimgw/single_event/waveform.py index b763e153..e2084ea1 100644 --- a/src/jimgw/single_event/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -93,7 +93,8 @@ def __init__(self, f_ref: float = 20.0, use_lambda_tildes: bool = False): self.f_ref = f_ref self.use_lambda_tildes = use_lambda_tildes - def __call__(self, frequency: Float[Array, " n_dim"], params: dict[str, Float] + def __call__( + self, frequency: Float[Array, " n_dim"], params: dict[str, Float] ) -> dict[str, Float[Array, " n_dim"]]: output = {} @@ -152,7 +153,8 @@ def __init__( self.use_lambda_tildes = use_lambda_tildes self.no_taper = no_taper - def __call__(self, frequency: Float[Array, " n_dim"], params: dict[str, Float] + def __call__( + self, frequency: Float[Array, " n_dim"], params: dict[str, Float] ) -> dict[str, Float[Array, " n_dim"]]: output = {}