Skip to content

Commit

Permalink
Merge pull request #15 from ThibeauWouters/haukekoehn-train_fluxes
Browse files Browse the repository at this point in the history
Haukekoehn train fluxes
  • Loading branch information
ThibeauWouters authored Dec 19, 2024
2 parents 5ba014e + d90fba9 commit 187c291
Show file tree
Hide file tree
Showing 193 changed files with 3,681 additions and 411 deletions.
21 changes: 16 additions & 5 deletions benchmarks/GRB/benchmark_afterglowpy_tophat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,40 @@
model_dir = f"../../trained_models/GRB/afterglowpy/{name}/"
FILTERS = ["radio-6GHz", "radio-3GHz"]#["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"]


for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]:
if metric_name == "$\\mathcal{L}_2$":
file_ending = "L2"
else:
file_ending = "Linf"


B = Benchmarker(name = "tophat",
B = Benchmarker(name = name,
parameter_grid = parameter_grid,
model_dir = model_dir,
MODEL = AfterglowpyLightcurvemodel,
filters = FILTERS,
n_test_data = 2000,
metric_name = metric_name,
remake_test_data = True,
remake_test_data = False,
jet_type = -1,
)

fig, ax = B.plot_error_distribution("radio-6GHz")


for filt in FILTERS:

fig, ax = B.plot_lightcurves_mismatch(filter =filt)
fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)
fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{core}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"])
fig.savefig(f"./benchmarks/{name}/benchmark_{filt}_{file_ending}.pdf", dpi = 200)

B.print_correlations(filter = filt)


if metric_name == "$\\mathcal{L}_\infty$":
fig, ax = B.plot_error_distribution(filt)
fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200)


fig, ax = B.plot_worst_lightcurve(filter = filt)
fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)

11 changes: 7 additions & 4 deletions benchmarks/KN/benchmark_Bu2019lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)

B.print_correlations(filter = filt)


fig, ax = B.plot_worst_lightcurve(filter = filt)
fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)


if metric_name == "$\\mathcal{L}_\infty$":
fig, ax = B.plot_error_distribution(filt)
fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200)


fig, ax = B.plot_worst_lightcurves()
fig.savefig(f"./benchmarks/{name}/worst_lightcurves_{file_ending}.pdf", dpi = 200)


2 changes: 1 addition & 1 deletion examples/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
figures/
results_*.npz
27 changes: 27 additions & 0 deletions examples/GRB/bash.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash -l
#Set job requirements
#SBATCH -N 1
#SBATCH -n 1
#SBATCH -p gpu
#SBATCH -t 00:30:00
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-gpu=1
#SBATCH --mem-per-gpu=5G
#SBATCH --output=outdir_GRB170817_tophat/log.out
#SBATCH --job-name=GRB170817

now=$(date)
echo "$now"

# Loading modules
# module load 2024
# module load Python/3.10.4-GCCcore-11.3.0
conda activate /home/twouters2/miniconda3/envs/ninjax

# Display GPU name
nvidia-smi --query-gpu=name --format=csv,noheader

# Run the script
python run_GRB170817_tophat.py

echo "DONE"
213 changes: 213 additions & 0 deletions examples/GRB/injection_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Injection runs with afterglowpy gaussian"""

import os
import jax
print(f"GPU found? {jax.devices()}")
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
import corner

from fiesta.inference.lightcurve_model import AfterglowpyPCA, PCALightcurveModel
from fiesta.inference.injection import InjectionRecoveryAfterglowpy
from fiesta.inference.likelihood import EMLikelihood
from fiesta.inference.prior import Uniform, CompositePrior, Constraint
from fiesta.inference.prior_dict import ConstrainedPrior
from fiesta.inference.fiesta import Fiesta
from fiesta.utils import load_event_data, write_event_data

import time
start_time = time.time()

################
### Preamble ###
################

jax.config.update("jax_enable_x64", True)

params = {"axes.grid": True,
"text.usetex" : True,
"font.family" : "serif",
"ytick.color" : "black",
"xtick.color" : "black",
"axes.labelcolor" : "black",
"axes.edgecolor" : "black",
"font.serif" : ["Computer Modern Serif"],
"xtick.labelsize": 16,
"ytick.labelsize": 16,
"axes.labelsize": 16,
"legend.fontsize": 16,
"legend.title_fontsize": 16,
"figure.titlesize": 16}

plt.rcParams.update(params)

default_corner_kwargs = dict(bins=40,
smooth=1.,
label_kwargs=dict(fontsize=16),
title_kwargs=dict(fontsize=16),
color="blue",
# quantiles=[],
# levels=[0.9],
plot_density=True,
plot_datapoints=False,
fill_contours=True,
max_n_ticks=4,
min_n_ticks=3,
save=False,
truth_color="red")


##############
### MODEL ###
##############

name = "gaussian"
model_dir = f"../../flux_models/afterglowpy_{name}/model"
FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]

model = AfterglowpyPCA(name,
model_dir,
filters = FILTERS)


###################
### INJECT ###
### AFTERGLOWPY ###
###################

trigger_time = 58849 # 01-01-2020 in mjd
remake_injection = False
injection_dict = {"inclination_EM": 0.174, "log10_E0": 54.4, "thetaCore": 0.14, "alphaWing": 3, "p": 2.6, "log10_n0": -2, "log10_epsilon_e": -2.06, "log10_epsilon_B": -4.2, "luminosity_distance": 40.0}

if remake_injection:
injection = InjectionRecoveryAfterglowpy(injection_dict, jet_type = 0, filters = FILTERS, N_datapoints = 70, error_budget = 0.5, tmin = 1, tmax = 2000, trigger_time = trigger_time)
injection.create_injection()
data = injection.data
write_event_data("./injection_gaussian/injection_gaussian.dat", data)

data = load_event_data("./injection_gaussian/injection_gaussian.dat")
#############################
### PRIORS AND LIKELIHOOD ###
#############################

inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0'])
thetaCore = Uniform(xmin=0.01, xmax=np.pi/5, naming=['thetaCore'])
alphaWing = Uniform(xmin = 0.2, xmax = 3.5, naming= ["alphaWing"])
thetaWing = Constraint(xmin = 0, xmax = np.pi/2, naming = ["thetaWing"])
log10_n0 = Uniform(xmin=-6.0, xmax=2.0, naming=['log10_n0'])
p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
log10_epsilon_e = Uniform(xmin=-4.0, xmax=0.0, naming=['log10_epsilon_e'])
log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])
epsilon_tot = Constraint(xmin = 0, xmax = 1, naming = ["epsilon_tot"])

# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])
def conversion_function(sample):
converted_sample = sample
converted_sample["thetaWing"] = converted_sample["thetaCore"] * converted_sample["alphaWing"]
converted_sample["epsilon_tot"] = 10**(converted_sample["log10_epsilon_B"]) + 10**(converted_sample["log10_epsilon_e"])
return converted_sample

prior_list = [inclination_EM,
log10_E0,
thetaCore,
alphaWing,
log10_n0,
p,
log10_epsilon_e,
log10_epsilon_B,
thetaWing,
epsilon_tot]

prior = ConstrainedPrior(prior_list, conversion_function)

detection_limit = None
likelihood = EMLikelihood(model,
data,
FILTERS,
tmax = 2000.0,
trigger_time=trigger_time,
detection_limit = detection_limit,
fixed_params={"luminosity_distance": 40.0},
error_budget = 1e-5)


##############
### FIESTA ###
##############

mass_matrix = jnp.eye(prior.n_dim)
eps = 5e-3
local_sampler_arg = {"step_size": mass_matrix * eps}

# Save for postprocessing
outdir = f"./injection_{name}/"
if not os.path.exists(outdir):
os.makedirs(outdir)

fiesta = Fiesta(likelihood,
prior,
n_chains = 1_000,
n_loop_training = 7,
n_loop_production = 3,
num_layers = 4,
hidden_size = [64, 64],
n_epochs = 20,
n_local_steps = 50,
n_global_steps = 200,
local_sampler_arg=local_sampler_arg,
outdir = outdir)

fiesta.sample(jax.random.PRNGKey(42))

fiesta.print_summary()

name = outdir + f'results_training.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, log_prob=log_prob, local_accs=local_accs,
global_accs=global_accs, loss_vals=loss_vals)

# - production phase
name = outdir + f'results_production.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=False)
chains, log_prob, local_accs, global_accs = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, chains=chains, log_prob=log_prob,
local_accs=local_accs, global_accs=global_accs)

################
### PLOTTING ###
################
# Fixed names: do not include them in the plotting, as will break corner
parameter_names = prior.naming
truths = [injection_dict[key] for key in parameter_names]

n_chains, n_steps, n_dim = np.shape(chains)
samples = np.reshape(chains, (n_chains * n_steps, n_dim))
samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption

corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs)
plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight')
plt.close()

end_time = time.time()
runtime_seconds = end_time - start_time
number_of_minutes = runtime_seconds // 60
number_of_seconds = np.round(runtime_seconds % 60, 2)
print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s")

print("Plotting lightcurves")
fiesta.plot_lightcurves()
print("Plotting lightcurves . . . done")

print("DONE")
1 change: 1 addition & 0 deletions examples/GRB/injection_gaussian/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file added examples/GRB/injection_gaussian/corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 70 additions & 0 deletions examples/GRB/injection_gaussian/injection_gaussian.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
2020-01-02T00:00:00.000 radio-3GHz 6.473406 0.500000
2020-04-10T22:48:00.000 radio-3GHz 12.246090 0.500000
2020-07-19T21:36:00.000 radio-3GHz 13.860922 0.500000
2020-10-27T20:24:00.000 radio-3GHz 15.145347 0.500000
2021-02-04T19:12:00.000 radio-3GHz 16.200114 0.500000
2021-05-15T18:00:00.000 radio-3GHz 17.037075 0.500000
2021-08-23T16:48:00.000 radio-3GHz 17.709738 0.500000
2021-12-01T15:36:00.000 radio-3GHz 18.265443 0.500000
2022-03-11T14:24:00.000 radio-3GHz 18.738014 0.500000
2022-06-19T13:12:00.000 radio-3GHz 19.142404 0.500000
2022-09-27T12:00:00.000 radio-3GHz 19.496765 0.500000
2023-01-05T10:48:00.000 radio-3GHz 19.810069 0.500000
2023-04-15T09:36:00.000 radio-3GHz 20.091715 0.500000
2023-07-24T08:24:00.000 radio-3GHz 20.346828 0.500000
2023-11-01T07:12:00.000 radio-3GHz 20.578875 0.500000
2024-02-09T06:00:00.000 radio-3GHz 20.791826 0.500000
2024-05-19T04:48:00.000 radio-3GHz 20.989639 0.500000
2024-08-27T03:36:00.000 radio-3GHz 21.171644 0.500000
2024-12-05T02:24:00.000 radio-3GHz 21.341979 0.500000
2025-03-15T01:12:00.000 radio-3GHz 21.500904 0.500000
2025-06-23T00:00:00.000 radio-3GHz 21.652456 0.500000
2020-01-02T00:00:00.000 radio-6GHz 6.642397 0.500000
2020-05-05T22:30:00.000 radio-6GHz 13.306995 0.500000
2020-09-07T21:00:00.000 radio-6GHz 15.133077 0.500000
2021-01-10T19:30:00.000 radio-6GHz 16.559615 0.500000
2021-05-15T18:00:00.000 radio-6GHz 17.639135 0.500000
2021-09-17T16:30:00.000 radio-6GHz 18.460232 0.500000
2022-01-20T15:00:00.000 radio-6GHz 19.112092 0.500000
2022-05-25T13:30:00.000 radio-6GHz 19.646921 0.500000
2022-09-27T12:00:00.000 radio-6GHz 20.098825 0.500000
2023-01-30T10:30:00.000 radio-6GHz 20.487108 0.500000
2023-06-04T09:00:00.000 radio-6GHz 20.824446 0.500000
2023-10-07T07:30:00.000 radio-6GHz 21.125451 0.500000
2024-02-09T06:00:00.000 radio-6GHz 21.393886 0.500000
2024-06-13T04:30:00.000 radio-6GHz 21.637560 0.500000
2024-10-16T03:00:00.000 radio-6GHz 21.859776 0.500000
2025-02-18T01:30:00.000 radio-6GHz 22.064226 0.500000
2025-06-23T00:00:00.000 radio-6GHz 22.254516 0.500000
2020-01-02T00:00:00.000 X-ray-1keV 21.209144 0.500000
2020-04-06T04:34:17.143 X-ray-1keV 27.962903 0.500000
2020-07-10T09:08:34.286 X-ray-1keV 29.537845 0.500000
2020-10-13T13:42:51.429 X-ray-1keV 30.789609 0.500000
2021-01-16T18:17:08.571 X-ray-1keV 31.828374 0.500000
2021-04-21T22:51:25.714 X-ray-1keV 32.666384 0.500000
2021-07-26T03:25:42.857 X-ray-1keV 33.344842 0.500000
2021-10-29T08:00:00.000 X-ray-1keV 33.904345 0.500000
2022-02-01T12:34:17.143 X-ray-1keV 34.378401 0.500000
2022-05-07T17:08:34.286 X-ray-1keV 34.787830 0.500000
2022-08-10T21:42:51.429 X-ray-1keV 35.144792 0.500000
2022-11-14T02:17:08.571 X-ray-1keV 35.464131 0.500000
2023-02-17T06:51:25.714 X-ray-1keV 35.747110 0.500000
2023-05-23T11:25:42.857 X-ray-1keV 36.004242 0.500000
2023-08-26T16:00:00.000 X-ray-1keV 36.238919 0.500000
2023-11-29T20:34:17.143 X-ray-1keV 36.454925 0.500000
2024-03-04T01:08:34.286 X-ray-1keV 36.654020 0.500000
2024-06-07T05:42:51.429 X-ray-1keV 36.837070 0.500000
2024-09-10T10:17:08.571 X-ray-1keV 37.010718 0.500000
2024-12-14T14:51:25.714 X-ray-1keV 37.169814 0.500000
2025-03-19T19:25:42.857 X-ray-1keV 37.321108 0.500000
2025-06-23T00:00:00.000 X-ray-1keV 37.465122 0.500000
2020-01-02T00:00:00.000 bessellv 15.916906 0.500000
2020-08-11T02:40:00.000 bessellv 24.685939 0.500000
2021-03-21T05:20:00.000 bessellv 27.115049 0.500000
2021-10-29T08:00:00.000 bessellv 28.612108 0.500000
2022-06-08T10:40:00.000 bessellv 29.619683 0.500000
2023-01-16T13:20:00.000 bessellv 30.363185 0.500000
2023-08-26T16:00:00.000 bessellv 30.946682 0.500000
2024-04-04T18:40:00.000 bessellv 31.423505 0.500000
2024-11-12T21:20:00.000 bessellv 31.826433 0.500000
2025-06-23T00:00:00.000 bessellv 32.172885 0.500000
Binary file added examples/GRB/injection_gaussian/lightcurves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 187c291

Please sign in to comment.