Adds support for IMRPhenomPv2 (#196)
* update to ml4gw 0.6.2

* add optional conversion function to parameter sampler

* add precessing and aligned versions of cbc prior

* add precessing parameters

* update to 0.6.3

* restrict torch to < 2.5.0

* scale lr by sqrt world size

* fix pre-commit
EthanMarx authored Jan 12, 2025
1 parent 769e978 commit 1818245
Showing 5 changed files with 933 additions and 770 deletions.
26 changes: 24 additions & 2 deletions amplfi/train/data/utils/utils.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict
+from typing import Callable, Dict, Optional

import torch

@@ -25,18 +25,40 @@ def forward(


class ParameterSampler(torch.nn.Module):
-    def __init__(self, **parameters: Callable):
+    def __init__(
+        self,
+        conversion_function: Optional[Callable] = None,
+        **parameters: Callable
+    ):
        """
        A class for sampling parameters from a prior distribution

        Args:
            conversion_function:
                A callable that takes a dictionary of sampled parameters
                and returns a dictionary of waveform generation parameters
            **parameters:
                A dictionary of parameter samplers that take an integer N
                and return a tensor of shape (N, ...) representing
                samples from the prior distribution
        """
        super().__init__()
        self.parameters = parameters
        self.conversion_function = conversion_function or (lambda x: x)

    def forward(
        self,
        N: int,
        device: str = "cpu",
    ):
        # sample parameters from the prior
        parameters = {
            k: v.sample((N,)).to(device) for k, v in self.parameters.items()
        }
        # apply any necessary conversions from the
        # sampled parameters to waveform generation parameters
        parameters = self.conversion_function(parameters)
        return parameters


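For context, here is a minimal sketch (not part of the commit) of how the new conversion_function hook can be used. The total_mass prior and the to_component_masses helper below are hypothetical, chosen only to illustrate the mapping from sampled parameters to waveform generation parameters:

import torch
from torch.distributions import Uniform

from amplfi.train.data.utils.utils import ParameterSampler


def to_component_masses(parameters):
    # hypothetical conversion: turn a sampled total mass and mass ratio
    # (q = m2 / m1 <= 1) into the component masses a waveform model expects
    total_mass = parameters.pop("total_mass")
    mass_ratio = parameters.pop("mass_ratio")
    parameters["mass_1"] = total_mass / (1 + mass_ratio)
    parameters["mass_2"] = total_mass - parameters["mass_1"]
    return parameters


sampler = ParameterSampler(
    conversion_function=to_component_masses,
    total_mass=Uniform(torch.tensor(20.0), torch.tensor(200.0)),
    mass_ratio=Uniform(torch.tensor(0.125), torch.tensor(0.999)),
)
samples = sampler(100)  # dict with keys "mass_1" and "mass_2"

Because conversion_function defaults to the identity, existing priors that need no conversion are unaffected.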
3 changes: 2 additions & 1 deletion amplfi/train/models/base.py
@@ -1,4 +1,5 @@
import logging
import math
import sys
from pathlib import Path
from typing import Optional
@@ -99,7 +100,7 @@ def configure_optimizers(self):
        else:
            world_size = torch.distributed.get_world_size()

-        lr = self.hparams.learning_rate * world_size
+        lr = self.hparams.learning_rate * math.sqrt(world_size)
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=lr,
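The switch from linear to square-root scaling is easy to see with illustrative numbers (this comparison is the editor's, not from the commit); the square-root rule grows the learning rate much more conservatively as the effective batch size increases with world size:

import math

base_lr = 1e-3  # illustrative value, not from the commit

# compare the old linear rule against the new square-root rule
for world_size in (1, 4, 16, 64):
    linear = base_lr * world_size
    sqrt = base_lr * math.sqrt(world_size)
    print(f"world_size={world_size:3d} linear={linear:.4g} sqrt={sqrt:.4g}")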
109 changes: 108 additions & 1 deletion amplfi/train/priors.py
@@ -2,6 +2,10 @@

import torch
from ml4gw import distributions
from ml4gw.waveforms.conversion import (
    bilby_spins_to_lalsim,
    chirp_mass_and_mass_ratio_to_components,
)
from torch.distributions import Uniform

from .data.utils.utils import ParameterSampler, ParameterTransformer
@@ -24,7 +28,10 @@ def sg_prior() -> ParameterSampler:


# priors and parameter transformers for cbc use case
-def cbc_prior() -> ParameterSampler:
+def aligned_cbc_prior() -> ParameterSampler:
    """
    Prior for aligned-spin CBC waveform generation, e.g. with IMRPhenomD
    """
    return ParameterSampler(
        chirp_mass=Uniform(
            torch.as_tensor(10, dtype=torch.float32),
@@ -57,6 +64,106 @@ def cbc_prior() -> ParameterSampler:
    )


def precessing_cbc_prior() -> ParameterSampler:
    """
    Prior for precessing-spin CBC waveform generation, e.g. with IMRPhenomPv2
    """

    def conversion_function(
        parameters: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        chirp_mass, mass_ratio = parameters.pop("chirp_mass"), parameters.pop(
            "mass_ratio"
        )
        mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components(
            chirp_mass, mass_ratio
        )

        a_1, a_2 = parameters.pop("a_1"), parameters.pop("a_2")
        tilt_1, tilt_2 = parameters.pop("tilt_1"), parameters.pop("tilt_2")
        inclination = parameters.pop("inclination")
        phi_jl, phi_12 = parameters.pop("phi_jl"), parameters.pop("phi_12")
        phic = parameters.pop("phic")

        # TODO: hard-coding f_ref = 40 here because it's not yet clear
        # how best to link this to the f_ref specified in the config file
        incl, s1x, s1y, s1z, s2x, s2y, s2z = bilby_spins_to_lalsim(
            inclination,
            phi_jl,
            tilt_1,
            tilt_2,
            phi_12,
            a_1,
            a_2,
            mass_1,
            mass_2,
            40,
            torch.zeros(len(inclination), device=inclination.device),
        )

        output = {}
        output["chirp_mass"] = chirp_mass
        output["mass_ratio"] = mass_ratio
        output["s1x"] = s1x
        output["s1y"] = s1y
        output["s1z"] = s1z
        output["s2x"] = s2x
        output["s2y"] = s2y
        output["s2z"] = s2z
        output["distance"] = parameters["distance"]
        output["inclination"] = incl
        output["phic"] = phic
        return output

    return ParameterSampler(
        conversion_function=conversion_function,
        chirp_mass=Uniform(
            torch.as_tensor(10, dtype=torch.float32),
            torch.as_tensor(100, dtype=torch.float32),
        ),
        mass_ratio=Uniform(
            torch.as_tensor(0.125, dtype=torch.float32),
            torch.as_tensor(0.999, dtype=torch.float32),
        ),
        distance=Uniform(
            torch.as_tensor(100, dtype=torch.float32),
            torch.as_tensor(3100, dtype=torch.float32),
        ),
        inclination=distributions.Sine(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(torch.pi, dtype=torch.float32),
        ),
        phic=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(2 * torch.pi, dtype=torch.float32),
        ),
        a_1=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(0.999, dtype=torch.float32),
        ),
        a_2=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(0.999, dtype=torch.float32),
        ),
        tilt_1=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(torch.pi, dtype=torch.float32),
        ),
        tilt_2=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(torch.pi, dtype=torch.float32),
        ),
        phi_jl=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(2 * torch.pi, dtype=torch.float32),
        ),
        phi_12=Uniform(
            torch.as_tensor(0, dtype=torch.float32),
            torch.as_tensor(2 * torch.pi, dtype=torch.float32),
        ),
    )


def cbc_testing_delta_function_prior() -> ParameterSampler:
    return ParameterSampler(
        chirp_mass=distributions.DeltaFunction(
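A short usage sketch (assuming the package layout shown above) of the new precessing prior; the keys listed in the comment are exactly those produced by the conversion function in this diff:

from amplfi.train.priors import precessing_cbc_prior

sampler = precessing_cbc_prior()
parameters = sampler(512)  # draw 512 samples on the CPU

# the bilby-style spins have already been converted to Cartesian
# components, so the returned keys are: chirp_mass, mass_ratio,
# s1x, s1y, s1z, s2x, s2y, s2z, distance, inclination, phic
print(sorted(parameters))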