Skip to content

Commit

Permalink
Reproducible results scripts (#43)
Browse files Browse the repository at this point in the history
* reorgnize somse stuff

* more stuff to organize

* refactor

* figures for exp1

* exp1 readme

* scripts to obtain figures for exp1

* correction

* generalize so the same function works in fixed flux case and not

* use generalized version

* set of figures for experiment 1

* corrections add multiplicative bias histogram

* being careful

* the angle domain should be between (0, pi)

* added new test on scaling of scatter

* name as a test

* fix test

* command to make functions

* remove line

* sometimes it is useful to split them up

* when inferring f, need to pop lf

* return info in separate functions

* absorb into single script via booleans

* this is how we fixed certain parameters during inference

* draft of exp 2 with galaxies between snr (8, 100)

* correction

* dictionary missing

* small corrections and additions for final script

* more experiments

* adapt info might be interesteing

* check std

* add shell script

* will berak it

* ruff

* draft figures

* add timing results

* reorganize a b it

* improvements

* improvements

* first set of figures, might run for more samples and see if 2000 chains breaks it

* determinism and outliers

* seed

* determinism elsewhere

* readme

* might need to revisit adding this kwarg

* a little more data could be interesting

* redo script, now fix everything except shapes

* rename

* move script

* expanding details

* take long enough that slurm script is helpful

* need to overwrite by default

* figs exp 3

* notebooks

* need more time

* reduce shape

* notebooks

* small note
  • Loading branch information
ismael-mendoza authored Nov 21, 2024
1 parent 0409d02 commit 68e15c8
Show file tree
Hide file tree
Showing 40 changed files with 1,932 additions and 192 deletions.
49 changes: 49 additions & 0 deletions bpd/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,55 @@ def one_step(state, rng_key):
return (states, infos)


def run_warmup_nuts(
rng_key: PRNGKeyArray,
init_positions: ArrayLike,
data: ArrayLike,
*,
logtarget: Callable,
initial_step_size: float,
max_num_doublings: int,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
target_acceptance_rate: float = 0.8,
):
_logtarget = partial(logtarget, data=data)
warmup = blackjax.window_adaptation(
blackjax.nuts,
_logtarget,
progress_bar=False,
is_mass_matrix_diagonal=is_mass_matrix_diagonal,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
target_acceptance_rate=target_acceptance_rate,
)

(init_states, tuned_params), adapt_info = warmup.run(
rng_key, init_positions, n_warmup_steps
)
return init_states, tuned_params, adapt_info


def run_sampling_nuts(
rng_key: PRNGKeyArray,
init_states: ArrayLike,
tuned_params: dict,
data: ArrayLike,
*,
logtarget: Callable,
n_samples: int,
max_num_doublings=5,
):
_logtarget = partial(logtarget, data=data)
kernel = blackjax.nuts(
_logtarget, **tuned_params, max_num_doublings=max_num_doublings
).step
states, info = inference_loop(
rng_key, init_states, kernel=kernel, n_samples=n_samples
)
return states.position, info


def run_inference_nuts(
rng_key: PRNGKeyArray,
init_positions: ArrayLike,
Expand Down
1 change: 0 additions & 1 deletion bpd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def load_dataset(fpath: str) -> dict[str, Array]:
assert Path(fpath).suffix == ".npz"

ds = {}

npzfile = jnp.load(fpath)
for k in npzfile.files:
ds[k] = npzfile[k]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""In this script, we fix the flux and HLR to truth when doing fits."""

from functools import partial
from typing import Callable

Expand All @@ -10,15 +8,15 @@
from jax.scipy import stats

from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.draw import draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation


def get_target_galaxy_params_simple(
def sample_target_galaxy_params_simple(
rng_key: PRNGKeyArray,
*,
shape_noise: float = 1e-3,
shape_noise: float,
g1: float = 0.02,
g2: float = 0.0,
):
Expand All @@ -42,6 +40,66 @@ def get_target_galaxy_params_simple(
}


# interim prior
def logprior(
params: dict[str, Array],
*,
sigma_e: float,
sigma_x: float = 0.5, # pixels
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
all_free: bool = True,
) -> Array:
prior = jnp.array(0.0)

if all_free:
f1, f2 = flux_bds
prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1)

h1, h2 = hlr_bds
prior += stats.uniform.logpdf(params["hlr"], h1, h2 - h1)

# NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image.
# sigma_x in units of pixels.
prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

e_mag = jnp.sqrt(params["e1"] ** 2 + params["e2"] ** 2)
prior += jnp.log(ellip_mag_prior(e_mag, sigma=sigma_e))

return prior


def loglikelihood(
params: dict[str, Array],
data: Array,
*,
draw_fnc: Callable,
background: float,
free_flux: bool = True,
):
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments.
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params} # function is more general

if free_flux:
_draw_params["f"] = 10 ** _draw_params.pop("lf")
model = draw_fnc(**_draw_params)

likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
likelihood = jnp.sum(likelihood_pp)
return likelihood


def logtarget(
params: dict[str, Array],
data: Array,
*,
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
Expand Down Expand Up @@ -89,70 +147,32 @@ def get_target_images(
return jnp.concatenate(target_images, axis=0)


# interim prior
def logprior(
params: dict[str, Array], *, sigma_e: float, sigma_x: float = 0.5
) -> Array:
prior = jnp.array(0.0)

e_mag = jnp.sqrt(params["e1"] ** 2 + params["e2"] ** 2)
prior += jnp.log(ellip_mag_prior(e_mag, sigma=sigma_e))

# NOTE: hard-coded assumption that galaxy is in center-pixel within odd-size image.
# sigma_x in units of pixels.
prior += stats.norm.logpdf(params["x"], loc=0.0, scale=sigma_x)
prior += stats.norm.logpdf(params["y"], loc=0.0, scale=sigma_x)

return prior


def loglikelihood(
params: dict[str, Array], data: Array, *, draw_fnc: Callable, background: float
):
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments.
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params} # function is more general
model = draw_fnc(**_draw_params)

likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
likelihood = jnp.sum(likelihood_pp)
return likelihood


def logtarget(
params: dict[str, Array],
data: Array,
*,
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def pipeline_image_interim_samples_one_galaxy(
def pipeline_interim_samples_one_galaxy(
rng_key: PRNGKeyArray,
true_params: dict[str, float],
target_image: Array,
fixed_draw_kwargs: dict,
*,
initialization_fnc: Callable,
draw_fnc: Callable,
logprior: Callable,
sigma_e_int: float,
f: float,
hlr: float,
n_samples: int = 100,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
free_flux: bool = True,
):
# Flux and HLR are fixed to truth and not inferred in this function.
k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)

_draw_fnc = partial(draw_gaussian, f=f, hlr=hlr, slen=slen, fft_size=fft_size)
_loglikelihood = partial(loglikelihood, draw_fnc=_draw_fnc, background=background)
_draw_fnc = partial(draw_fnc, **fixed_draw_kwargs)
_loglikelihood = partial(
loglikelihood, draw_fnc=_draw_fnc, background=background, free_flux=free_flux
)
_logprior = partial(logprior, sigma_e=sigma_e_int)

_logtarget = partial(
Expand Down
2 changes: 1 addition & 1 deletion bpd/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def sample_ellip_prior(rng_key, sigma: float, n: int = 1):
"""Sample n ellipticities isotropic components with Gary's prior from magnitude."""
key1, key2 = random.split(rng_key, 2)
e_mag = sample_mag_ellip_prior(key1, sigma=sigma, n=n)
e_phi = random.uniform(key2, shape=(n,), minval=0, maxval=2 * jnp.pi)
e_phi = random.uniform(key2, shape=(n,), minval=0, maxval=jnp.pi)
e1 = e_mag * jnp.cos(2 * e_phi)
e2 = e_mag * jnp.sin(2 * e_phi)
return jnp.stack((e1, e2), axis=1)
Expand Down
11 changes: 11 additions & 0 deletions experiments/exp1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Experiment 1

This folder contains scripts to reproduce results on the first experiment. The scope of this
experiment is posteriors, biases, and calibration on toy ellipticities in the low noise setting.

The convergence tests are in the form of unit tests for this experiment in `test_convergence.py`

## Scripts

* `get_posteriors.sh`: Get 1000 shear posteriors using the slurm script.
* `get_figures_shear.sh`: Obtain all relevant diagnostic plots and save to `figs` folder.
Binary file added experiments/exp1/figs/calibration.pdf
Binary file not shown.
Binary file added experiments/exp1/figs/contours.pdf
Binary file not shown.
Binary file not shown.
Binary file added experiments/exp1/figs/traces.pdf
Binary file not shown.
2 changes: 2 additions & 0 deletions experiments/exp1/get_figures.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
./make_figures.py
2 changes: 2 additions & 0 deletions experiments/exp1/get_posteriors.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
../../scripts/slurm_toy_shear_vectorized.py 42 toy_shear_42
110 changes: 110 additions & 0 deletions experiments/exp1/make_figures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#!/usr/bin/env python3

import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "True"

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import Array
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm

from bpd import DATA_DIR
from bpd.diagnostics import get_contour_plot, get_gauss_pc_fig, get_pc_fig


def make_trace_plots(g_samples: Array, n_examples: int = 10) -> None:
"""Make example figure showing example trace plots of shear posteriors."""
# by default, we choose 10 random traces to plot in 1 PDF file.
fname = "figs/traces.pdf"
with PdfPages(fname) as pdf:
assert g_samples.ndim == 3
n_post = g_samples.shape[0]
indices = np.random.choice(np.arange(n_post), (n_examples,))

for ii in tqdm(indices, desc="Saving traces"):
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5))
g1 = g_samples[ii, :, 0]
g2 = g_samples[ii, :, 0]

ax1.plot(g1)
ax2.plot(g2)
ax1.set_title(f"Index: {ii}")

pdf.savefig(fig)
plt.close(fig)


def make_contour_plots(g_samples: Array, n_examples=10) -> None:
"""Make example figure showing example contour plots of shear posterios"""
# by default, we choose 10 random contours to plot in 1 PDF file.
fname = "figs/contours.pdf"
with PdfPages(fname) as pdf:
assert g_samples.ndim == 3
n_post = g_samples.shape[0]
indices = np.random.choice(np.arange(n_post), (n_examples,))

truth = {"g1": 0.02, "g2": 0.0}

for ii in tqdm(indices, desc="Saving contours"):
g_dict = {"g1": g_samples[ii, :, 0], "g2": g_samples[ii, :, 1]}
fig = get_contour_plot([g_dict], [f"post_{ii}"], truth)
plt.suptitle(f"Index: {ii}")
pdf.savefig(fig)
plt.close(fig)


def make_posterior_calibration(g_samples: Array) -> None:
"""Output posterior calibration figure."""
# make two types, assuming gaussianity and one not assuming gaussianity.
fname = "figs/calibration.pdf"
with PdfPages(fname) as pdf:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
get_gauss_pc_fig(ax1, g_samples[..., 0], truth=0.02, param_name="g1 (gauss)")
get_pc_fig(ax2, g_samples[..., 0], truth=0.02, param_name="g1 (full)")
pdf.savefig(fig)
plt.close(fig)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
get_gauss_pc_fig(ax1, g_samples[..., 1], truth=0.0, param_name="g2 (gauss)")
get_pc_fig(ax2, g_samples[..., 1], truth=0.0, param_name="g2 (full)")
pdf.savefig(fig)
plt.close(fig)


def make_histogram_mbias(g_samples: Array) -> None:
fname = "figs/multiplicative_bias_hist.pdf"
with PdfPages(fname) as pdf:
g1 = g_samples[:, :, 0]
mbias = (g1.mean(axis=1) - 0.02) / 0.02
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.hist(mbias, bins=31, histtype="step")

pdf.savefig(fig)
plt.close(fig)


def main():
pdir = DATA_DIR / "cache_chains" / "toy_shear_42"
assert pdir.exists()
all_g_samples = []
for fpath in pdir.iterdir():
if "g_samples" in fpath.name:
_g_samples = jnp.load(fpath)
all_g_samples.append(_g_samples)
g_samples = jnp.concatenate(all_g_samples, axis=0)
assert g_samples.shape == (1000, 3000, 2)

# make plots
make_trace_plots(g_samples)
make_contour_plots(g_samples)
make_posterior_calibration(g_samples)
make_histogram_mbias(g_samples)


if __name__ == "__main__":
main()
Loading

0 comments on commit 68e15c8

Please sign in to comment.