Add mass & sky position transform #118

Merged
Commits (41 total; changes below shown from 36 commits)
39126f5
Add inverse mass transform
thomasckng Jul 30, 2024
a0161ee
Merge pull request #5 from kazewong/98-moving-naming-tracking-into-ji…
thomasckng Jul 31, 2024
dfdfffa
Add mass transform
thomasckng Jul 31, 2024
9d87e58
Add simplex transform
thomasckng Jul 31, 2024
5f33346
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
1e389bb
Add UniformComponentMassPrior
thomasckng Jul 31, 2024
fd32f20
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
1aaf5ea
Remove transformation
thomasckng Jul 31, 2024
5bfdda1
Remove prior
thomasckng Jul 31, 2024
77a6ad1
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
2401df8
Add mass transform
thomasckng Jul 31, 2024
f932c77
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
b5e06a6
Solve conflict
thomasckng Aug 1, 2024
b7d08d3
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Aug 1, 2024
8e5b326
Add sky position transform
thomasckng Aug 1, 2024
10d51b2
Modify sky position transform
thomasckng Aug 1, 2024
8368d00
Change util func name
thomasckng Aug 1, 2024
b4f6052
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Aug 1, 2024
1cb0d11
Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-p…
thomasckng Aug 1, 2024
78e93c5
Merge
thomasckng Aug 1, 2024
080bd8b
Modify integration test
thomasckng Aug 1, 2024
4058d32
Reformat
thomasckng Aug 1, 2024
cdf771d
Add typecheck
thomasckng Aug 1, 2024
02c5650
minor typo
thomasckng Aug 1, 2024
ce7ac34
Rename sampler
thomasckng Aug 1, 2024
7d44aa4
Fix test
thomasckng Aug 1, 2024
e9288c8
Fix BoundToUnbound transform
thomasckng Aug 1, 2024
fe500a7
Use ifos list
thomasckng Aug 1, 2024
0d28520
Fix jim summary and get_samples
thomasckng Aug 1, 2024
f7e3fe8
Fix jim output functions
thomasckng Aug 1, 2024
68bef54
Modify Transform
thomasckng Aug 1, 2024
87593e1
Fix jim output
thomasckng Aug 1, 2024
730fe31
Add comment
thomasckng Aug 1, 2024
5a5ff2f
Add sky position transform
thomasckng Aug 2, 2024
277e893
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
kazewong Aug 2, 2024
bd915ef
Add powerlaw transform back
kazewong Aug 2, 2024
ba86a65
Move single_event prior and transform
thomasckng Aug 2, 2024
3e1ea71
Tidy up test
thomasckng Aug 2, 2024
6ad882a
Add utils.py
thomasckng Aug 2, 2024
ede2b99
Move log_i0
thomasckng Aug 2, 2024
e764696
Fixing check
thomasckng Aug 2, 2024
src/jimgw/jim.py (108 changes: 80 additions, 28 deletions)
@@ -7,7 +7,7 @@
from jaxtyping import Array, Float, PRNGKeyArray

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior, trace_prior_parent
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


@@ -41,15 +41,18 @@ def __init__(
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print("No sample transforms provided. Using prior parameters as sampling parameters")
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print("No likelihood transforms provided. Using prior parameters as likelihood parameters")

print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

@@ -67,7 +70,7 @@ def __init__(
self.prior.n_dim, num_layers, hidden_size, num_bins, subkey
)

self.Sampler = Sampler(
self.sampler = Sampler(
self.prior.n_dim,
rng_key,
None, # type: ignore
@@ -91,22 +94,23 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
def posterior(self, params: Float[Array, " n_dim"], data: dict):
named_params = self.add_name(params)
transform_jacobian = 0.0
for transform in self.sample_transforms:
for transform in reversed(self.sample_transforms):
named_params, jacobian = transform.inverse(named_params)
transform_jacobian += jacobian
prior = self.prior.log_prob(named_params) + transform_jacobian
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate
return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate
return (
self.likelihood.evaluate(named_params, data) + prior
)

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.prior.sample(key, self.Sampler.n_chains)
initial_guess_named = self.prior.sample(key, self.sampler.n_chains)
for transform in self.sample_transforms:
initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate
self.Sampler.sample(initial_guess, None) # type: ignore
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.sampler.sample(initial_guess, None) # type: ignore

def maximize_likelihood(
self,
@@ -133,36 +137,64 @@ def negative_posterior(x: Float[Array, " n_dim"]):
best_fit = optimizer.get_result()[0]
return best_fit

def print_summary(self, transform: bool = True):
def print_summary(self):
"""
Generate summary of the run

"""

train_summary = self.Sampler.get_sampler_state(training=True)
production_summary = self.Sampler.get_sampler_state(training=False)
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.prior.add_name(training_chain)
if transform:
training_chain = self.prior.transform(training_chain)
training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names))
if self.sample_transforms:
# Need rewrite to vectorize
transformed_chain = {}
named_sample = self.add_name(training_chain[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in training_chain[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
training_chain = transformed_chain
else:
training_chain = self.add_name(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.prior.add_name(production_chain)
if transform:
production_chain = self.prior.transform(production_chain)
production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names))
if self.sample_transforms:
# Need rewrite to vectorize
transformed_chain = {}
named_sample = self.add_name(production_chain[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in production_chain[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
production_chain = transformed_chain
else:
production_chain = self.add_name(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]

print("Training summary")
print("=" * 10)
for key, value in training_chain.items():
print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}")
print(
f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}"
)
@@ -179,7 +211,7 @@ def print_summary(self, transform: bool = True):
print("Production summary")
print("=" * 10)
for key, value in production_chain.items():
print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}")
print(
f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}"
)
@@ -206,12 +238,32 @@ def get_samples(self, training: bool = False) -> dict:

"""
if training:
chains = self.Sampler.get_sampler_state(training=True)["chains"]
chains = self.sampler.get_sampler_state(training=True)["chains"]
else:
chains = self.sampler.get_sampler_state(training=False)["chains"]

# Need rewrite to output chains instead of flattened samples and vectorize
chains = chains.reshape(-1, len(self.parameter_names))
if self.sample_transforms:
transformed_chain = {}
named_sample = self.add_name(chains[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in chains[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
output = transformed_chain
else:
chains = self.Sampler.get_sampler_state(training=False)["chains"]
output = self.add_name(chains)

chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1)))
return chains
for key in output.keys():
output[key] = jnp.array(output[key])
return output

def plot(self):
pass
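Side note on the "# Need rewrite to vectorize" comments in print_summary and get_samples above: the per-sample Python loops could plausibly be replaced by a single jax.vmap over the chain axis. The sketch below is not part of this PR; it assumes each sample transform's backward maps a dict of scalar parameters to a dict of scalar parameters and is built from jax.numpy operations, so it can be traced under vmap.

import jax

def backward_all(sample_transforms, add_name, chains):
    """Map flattened samples back to the prior parameter space in one shot.

    chains: (n_samples, n_dim) array in the sampling space.
    Returns a dict of (n_samples,) arrays keyed by parameter name.
    """

    def backward_one(sample):
        named = add_name(sample)  # flat array -> {name: scalar}
        for transform in sample_transforms:
            named = transform.backward(named)
        return named

    # vmap unrolls the transform loop at trace time and maps it over the
    # leading sample axis, so no explicit Python loop over samples is needed.
    return jax.vmap(backward_one)(chains)

Inside get_samples this would replace the transformed_chain loops with something like output = backward_all(self.sample_transforms, self.add_name, chains.reshape(-1, len(self.parameter_names))); whether the existing transforms are actually vmap-compatible would still need to be checked.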
src/jimgw/prior.py (2 changes: 1 addition, 1 deletion)
@@ -399,7 +399,7 @@ def __init__(


@jaxtyped(typechecker=typechecker)
class UniformInComponentsChirpMassPrior(PowerLawPrior):
class UniformComponentChirpMassPrior(PowerLawPrior):
"""
A prior in the range [xmin, xmax) for chirp mass which assumes the
component masses to be uniformly distributed.
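For reference on the renamed UniformComponentChirpMassPrior: the reason a chirp-mass prior that is uniform in the component masses can be written as a PowerLawPrior comes from the change of variables from (m1, m2) to (Mc, q). This is a derivation sketch only; the power-law index the class actually passes to PowerLawPrior is outside the lines shown in this diff.

    m1 = Mc * (1+q)^(1/5) / q^(3/5),   m2 = q * m1
    |d(m1, m2) / d(Mc, q)| = Mc * (1+q)^(2/5) / q^(6/5)

A density uniform in (m1, m2) therefore factorizes into a chirp-mass factor p(Mc) proportional to Mc (a power law) and a mass-ratio factor proportional to ((1+q)/q^3)^(2/5).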