Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Haukekoehn train fluxes #15

Merged
merged 18 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions benchmarks/GRB/benchmark_afterglowpy_tophat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,40 @@
model_dir = f"../../trained_models/GRB/afterglowpy/{name}/"
FILTERS = ["radio-6GHz", "radio-3GHz"]#["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"]


for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]:
if metric_name == "$\\mathcal{L}_2$":
file_ending = "L2"
else:
file_ending = "Linf"


B = Benchmarker(name = "tophat",
B = Benchmarker(name = name,
parameter_grid = parameter_grid,
model_dir = model_dir,
MODEL = AfterglowpyLightcurvemodel,
filters = FILTERS,
n_test_data = 2000,
metric_name = metric_name,
remake_test_data = True,
remake_test_data = False,
jet_type = -1,
)

fig, ax = B.plot_error_distribution("radio-6GHz")


for filt in FILTERS:

fig, ax = B.plot_lightcurves_mismatch(filter =filt)
fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)
fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{core}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"])
fig.savefig(f"./benchmarks/{name}/benchmark_{filt}_{file_ending}.pdf", dpi = 200)

B.print_correlations(filter = filt)


if metric_name == "$\\mathcal{L}_\infty$":
fig, ax = B.plot_error_distribution(filt)
fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200)


fig, ax = B.plot_worst_lightcurve(filter = filt)
fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)

11 changes: 7 additions & 4 deletions benchmarks/KN/benchmark_Bu2019lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)

B.print_correlations(filter = filt)


fig, ax = B.plot_worst_lightcurve(filter = filt)
fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)


if metric_name == "$\\mathcal{L}_\infty$":
fig, ax = B.plot_error_distribution(filt)
fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200)


fig, ax = B.plot_worst_lightcurves()
fig.savefig(f"./benchmarks/{name}/worst_lightcurves_{file_ending}.pdf", dpi = 200)


2 changes: 1 addition & 1 deletion examples/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
figures/
results_*.npz
27 changes: 27 additions & 0 deletions examples/GRB/bash.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash -l
#Set job requirements
#SBATCH -N 1
#SBATCH -n 1
#SBATCH -p gpu
#SBATCH -t 00:30:00
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-gpu=1
#SBATCH --mem-per-gpu=5G
#SBATCH --output=outdir_GRB170817_tophat/log.out
#SBATCH --job-name=GRB170817

now=$(date)
echo "$now"

# Loading modules
# module load 2024
# module load Python/3.10.4-GCCcore-11.3.0
conda activate /home/twouters2/miniconda3/envs/ninjax

# Display GPU name
nvidia-smi --query-gpu=name --format=csv,noheader

# Run the script
python run_GRB170817_tophat.py

echo "DONE"
213 changes: 213 additions & 0 deletions examples/GRB/injection_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Injection runs with afterglowpy gaussian"""

import os
import jax
print(f"GPU found? {jax.devices()}")
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
import corner

from fiesta.inference.lightcurve_model import AfterglowpyPCA, PCALightcurveModel
from fiesta.inference.injection import InjectionRecoveryAfterglowpy
from fiesta.inference.likelihood import EMLikelihood
from fiesta.inference.prior import Uniform, CompositePrior, Constraint
from fiesta.inference.prior_dict import ConstrainedPrior
from fiesta.inference.fiesta import Fiesta
from fiesta.utils import load_event_data, write_event_data

import time
start_time = time.time()

################
### Preamble ###
################

jax.config.update("jax_enable_x64", True)

params = {"axes.grid": True,
"text.usetex" : True,
"font.family" : "serif",
"ytick.color" : "black",
"xtick.color" : "black",
"axes.labelcolor" : "black",
"axes.edgecolor" : "black",
"font.serif" : ["Computer Modern Serif"],
"xtick.labelsize": 16,
"ytick.labelsize": 16,
"axes.labelsize": 16,
"legend.fontsize": 16,
"legend.title_fontsize": 16,
"figure.titlesize": 16}

plt.rcParams.update(params)

default_corner_kwargs = dict(bins=40,
smooth=1.,
label_kwargs=dict(fontsize=16),
title_kwargs=dict(fontsize=16),
color="blue",
# quantiles=[],
# levels=[0.9],
plot_density=True,
plot_datapoints=False,
fill_contours=True,
max_n_ticks=4,
min_n_ticks=3,
save=False,
truth_color="red")


##############
### MODEL ###
##############

name = "gaussian"
model_dir = f"../../flux_models/afterglowpy_{name}/model"
FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]

model = AfterglowpyPCA(name,
model_dir,
filters = FILTERS)


###################
### INJECT ###
### AFTERGLOWPY ###
###################

trigger_time = 58849 # 01-01-2020 in mjd
remake_injection = False
injection_dict = {"inclination_EM": 0.174, "log10_E0": 54.4, "thetaCore": 0.14, "alphaWing": 3, "p": 2.6, "log10_n0": -2, "log10_epsilon_e": -2.06, "log10_epsilon_B": -4.2, "luminosity_distance": 40.0}

if remake_injection:
injection = InjectionRecoveryAfterglowpy(injection_dict, jet_type = 0, filters = FILTERS, N_datapoints = 70, error_budget = 0.5, tmin = 1, tmax = 2000, trigger_time = trigger_time)
injection.create_injection()
data = injection.data
write_event_data("./injection_gaussian/injection_gaussian.dat", data)

data = load_event_data("./injection_gaussian/injection_gaussian.dat")
#############################
### PRIORS AND LIKELIHOOD ###
#############################

inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0'])
thetaCore = Uniform(xmin=0.01, xmax=np.pi/5, naming=['thetaCore'])
alphaWing = Uniform(xmin = 0.2, xmax = 3.5, naming= ["alphaWing"])
thetaWing = Constraint(xmin = 0, xmax = np.pi/2, naming = ["thetaWing"])
log10_n0 = Uniform(xmin=-6.0, xmax=2.0, naming=['log10_n0'])
p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
log10_epsilon_e = Uniform(xmin=-4.0, xmax=0.0, naming=['log10_epsilon_e'])
log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])
epsilon_tot = Constraint(xmin = 0, xmax = 1, naming = ["epsilon_tot"])

# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])
def conversion_function(sample):
converted_sample = sample
converted_sample["thetaWing"] = converted_sample["thetaCore"] * converted_sample["alphaWing"]
converted_sample["epsilon_tot"] = 10**(converted_sample["log10_epsilon_B"]) + 10**(converted_sample["log10_epsilon_e"])
return converted_sample

prior_list = [inclination_EM,
log10_E0,
thetaCore,
alphaWing,
log10_n0,
p,
log10_epsilon_e,
log10_epsilon_B,
thetaWing,
epsilon_tot]

prior = ConstrainedPrior(prior_list, conversion_function)

detection_limit = None
likelihood = EMLikelihood(model,
data,
FILTERS,
tmax = 2000.0,
trigger_time=trigger_time,
detection_limit = detection_limit,
fixed_params={"luminosity_distance": 40.0},
error_budget = 1e-5)


##############
### FIESTA ###
##############

mass_matrix = jnp.eye(prior.n_dim)
eps = 5e-3
local_sampler_arg = {"step_size": mass_matrix * eps}

# Save for postprocessing
outdir = f"./injection_{name}/"
if not os.path.exists(outdir):
os.makedirs(outdir)

fiesta = Fiesta(likelihood,
prior,
n_chains = 1_000,
n_loop_training = 7,
n_loop_production = 3,
num_layers = 4,
hidden_size = [64, 64],
n_epochs = 20,
n_local_steps = 50,
n_global_steps = 200,
local_sampler_arg=local_sampler_arg,
outdir = outdir)

fiesta.sample(jax.random.PRNGKey(42))

fiesta.print_summary()

name = outdir + f'results_training.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, log_prob=log_prob, local_accs=local_accs,
global_accs=global_accs, loss_vals=loss_vals)

# - production phase
name = outdir + f'results_production.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=False)
chains, log_prob, local_accs, global_accs = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, chains=chains, log_prob=log_prob,
local_accs=local_accs, global_accs=global_accs)

################
### PLOTTING ###
################
# Fixed names: do not include them in the plotting, as will break corner
parameter_names = prior.naming
truths = [injection_dict[key] for key in parameter_names]

n_chains, n_steps, n_dim = np.shape(chains)
samples = np.reshape(chains, (n_chains * n_steps, n_dim))
samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption

corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs)
plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight')
plt.close()

end_time = time.time()
runtime_seconds = end_time - start_time
number_of_minutes = runtime_seconds // 60
number_of_seconds = np.round(runtime_seconds % 60, 2)
print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s")

print("Plotting lightcurves")
fiesta.plot_lightcurves()
print("Plotting lightcurves . . . done")

print("DONE")
1 change: 1 addition & 0 deletions examples/GRB/injection_gaussian/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file added examples/GRB/injection_gaussian/corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 70 additions & 0 deletions examples/GRB/injection_gaussian/injection_gaussian.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
2020-01-02T00:00:00.000 radio-3GHz 6.473406 0.500000
2020-04-10T22:48:00.000 radio-3GHz 12.246090 0.500000
2020-07-19T21:36:00.000 radio-3GHz 13.860922 0.500000
2020-10-27T20:24:00.000 radio-3GHz 15.145347 0.500000
2021-02-04T19:12:00.000 radio-3GHz 16.200114 0.500000
2021-05-15T18:00:00.000 radio-3GHz 17.037075 0.500000
2021-08-23T16:48:00.000 radio-3GHz 17.709738 0.500000
2021-12-01T15:36:00.000 radio-3GHz 18.265443 0.500000
2022-03-11T14:24:00.000 radio-3GHz 18.738014 0.500000
2022-06-19T13:12:00.000 radio-3GHz 19.142404 0.500000
2022-09-27T12:00:00.000 radio-3GHz 19.496765 0.500000
2023-01-05T10:48:00.000 radio-3GHz 19.810069 0.500000
2023-04-15T09:36:00.000 radio-3GHz 20.091715 0.500000
2023-07-24T08:24:00.000 radio-3GHz 20.346828 0.500000
2023-11-01T07:12:00.000 radio-3GHz 20.578875 0.500000
2024-02-09T06:00:00.000 radio-3GHz 20.791826 0.500000
2024-05-19T04:48:00.000 radio-3GHz 20.989639 0.500000
2024-08-27T03:36:00.000 radio-3GHz 21.171644 0.500000
2024-12-05T02:24:00.000 radio-3GHz 21.341979 0.500000
2025-03-15T01:12:00.000 radio-3GHz 21.500904 0.500000
2025-06-23T00:00:00.000 radio-3GHz 21.652456 0.500000
2020-01-02T00:00:00.000 radio-6GHz 6.642397 0.500000
2020-05-05T22:30:00.000 radio-6GHz 13.306995 0.500000
2020-09-07T21:00:00.000 radio-6GHz 15.133077 0.500000
2021-01-10T19:30:00.000 radio-6GHz 16.559615 0.500000
2021-05-15T18:00:00.000 radio-6GHz 17.639135 0.500000
2021-09-17T16:30:00.000 radio-6GHz 18.460232 0.500000
2022-01-20T15:00:00.000 radio-6GHz 19.112092 0.500000
2022-05-25T13:30:00.000 radio-6GHz 19.646921 0.500000
2022-09-27T12:00:00.000 radio-6GHz 20.098825 0.500000
2023-01-30T10:30:00.000 radio-6GHz 20.487108 0.500000
2023-06-04T09:00:00.000 radio-6GHz 20.824446 0.500000
2023-10-07T07:30:00.000 radio-6GHz 21.125451 0.500000
2024-02-09T06:00:00.000 radio-6GHz 21.393886 0.500000
2024-06-13T04:30:00.000 radio-6GHz 21.637560 0.500000
2024-10-16T03:00:00.000 radio-6GHz 21.859776 0.500000
2025-02-18T01:30:00.000 radio-6GHz 22.064226 0.500000
2025-06-23T00:00:00.000 radio-6GHz 22.254516 0.500000
2020-01-02T00:00:00.000 X-ray-1keV 21.209144 0.500000
2020-04-06T04:34:17.143 X-ray-1keV 27.962903 0.500000
2020-07-10T09:08:34.286 X-ray-1keV 29.537845 0.500000
2020-10-13T13:42:51.429 X-ray-1keV 30.789609 0.500000
2021-01-16T18:17:08.571 X-ray-1keV 31.828374 0.500000
2021-04-21T22:51:25.714 X-ray-1keV 32.666384 0.500000
2021-07-26T03:25:42.857 X-ray-1keV 33.344842 0.500000
2021-10-29T08:00:00.000 X-ray-1keV 33.904345 0.500000
2022-02-01T12:34:17.143 X-ray-1keV 34.378401 0.500000
2022-05-07T17:08:34.286 X-ray-1keV 34.787830 0.500000
2022-08-10T21:42:51.429 X-ray-1keV 35.144792 0.500000
2022-11-14T02:17:08.571 X-ray-1keV 35.464131 0.500000
2023-02-17T06:51:25.714 X-ray-1keV 35.747110 0.500000
2023-05-23T11:25:42.857 X-ray-1keV 36.004242 0.500000
2023-08-26T16:00:00.000 X-ray-1keV 36.238919 0.500000
2023-11-29T20:34:17.143 X-ray-1keV 36.454925 0.500000
2024-03-04T01:08:34.286 X-ray-1keV 36.654020 0.500000
2024-06-07T05:42:51.429 X-ray-1keV 36.837070 0.500000
2024-09-10T10:17:08.571 X-ray-1keV 37.010718 0.500000
2024-12-14T14:51:25.714 X-ray-1keV 37.169814 0.500000
2025-03-19T19:25:42.857 X-ray-1keV 37.321108 0.500000
2025-06-23T00:00:00.000 X-ray-1keV 37.465122 0.500000
2020-01-02T00:00:00.000 bessellv 15.916906 0.500000
2020-08-11T02:40:00.000 bessellv 24.685939 0.500000
2021-03-21T05:20:00.000 bessellv 27.115049 0.500000
2021-10-29T08:00:00.000 bessellv 28.612108 0.500000
2022-06-08T10:40:00.000 bessellv 29.619683 0.500000
2023-01-16T13:20:00.000 bessellv 30.363185 0.500000
2023-08-26T16:00:00.000 bessellv 30.946682 0.500000
2024-04-04T18:40:00.000 bessellv 31.423505 0.500000
2024-11-12T21:20:00.000 bessellv 31.826433 0.500000
2025-06-23T00:00:00.000 bessellv 32.172885 0.500000
Binary file added examples/GRB/injection_gaussian/lightcurves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading