Skip to content

Commit

Permalink
permit overriding min_radius_px from top-level callers
Browse files Browse the repository at this point in the history
It's currently necessary to override this parameter to make galsim unit tests
pass (see jeff-regier#534).

This is the simplest way to code it but it's ugly in my opinion. We should
discuss alternatives in the PR thread or an issue thread.
  • Loading branch information
gostevehoward committed Jan 25, 2017
1 parent 6370d74 commit 9948edb
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 18 deletions.
5 changes: 3 additions & 2 deletions src/DeterministicVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ Arguments:
"""
function infer_source(images::Vector{Image},
neighbors::Vector{CatalogEntry},
entry::CatalogEntry)
entry::CatalogEntry;
min_radius_pix=Nullable{Float64}())
if length(neighbors) > 100
Log.warn("Excessive number ($(length(neighbors))) of neighbors")
end
Expand All @@ -122,7 +123,7 @@ function infer_source(images::Vector{Image},
cat_local = vcat([entry], neighbors)
vp = init_sources([1], cat_local)
patches = Infer.get_sky_patches(images, cat_local)
Infer.load_active_pixels!(images, patches)
Infer.load_active_pixels!(images, patches, min_radius_pix=min_radius_pix)

ea = ElboArgs(images, vp, patches, [1])
f_evals, max_f, max_x, nm_result = maximize_f(elbo, ea)
Expand Down
5 changes: 3 additions & 2 deletions src/DeterministicVIImagePSF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ export elbo_likelihood_with_fft!, FSMSensitiveFloatMatrices,

function infer_source_fft(images::Vector{Image},
neighbors::Vector{CatalogEntry},
entry::CatalogEntry)
entry::CatalogEntry;
min_radius_pix=Nullable{Float64}())
cat_local = vcat([entry], neighbors)
vp = init_sources([1], cat_local)
patches = get_sky_patches(images, cat_local)
load_active_pixels!(images, patches)
load_active_pixels!(images, patches, min_radius_pix=min_radius_pix)

ea_fft, fsm_mat = initialize_fft_elbo_parameters(
images, vp, patches, [1], use_raw_psf=false)
Expand Down
15 changes: 13 additions & 2 deletions src/GalsimBenchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Celeste.ParallelRun: one_node_single_infer, one_node_joint_infer

const GALSIM_BENCHMARK_DIR = joinpath(Pkg.dir("Celeste"), "benchmark", "galsim")
const LATEST_FITS_FILENAME_DIR = joinpath(GALSIM_BENCHMARK_DIR, "latest_filenames")
const ACTIVE_PIXELS_MIN_RADIUS_PX = Nullable(40.0)

type GalsimFitsFileNotFound <: Exception end

Expand Down Expand Up @@ -308,6 +309,10 @@ end
# * error_sds: absolute error of estimate, divided by posterior standard deviation
function run_benchmarks(; test_case_names=String[], print_fn=println, joint_inference=false,
infer_source_callback=DeterministicVI.infer_source)
function infer_source_min_radius(args...; kwargs...)
infer_source_callback(args...; min_radius_pix=ACTIVE_PIXELS_MIN_RADIUS_PX, kwargs...)
end

extensions, wcs = load_galsim_fits("galsim_benchmarks")
all_benchmark_data = []
for test_case_index in 1:div(length(extensions), 5)
Expand All @@ -326,7 +331,13 @@ function run_benchmarks(; test_case_names=String[], print_fn=println, joint_infe
if joint_inference
target_sources = collect(1:num_sources)
neighbor_map = Infer.find_neighbors(target_sources, catalog, images)
results = one_node_joint_infer(catalog, target_sources, neighbor_map, images)
results = one_node_joint_infer(
catalog,
target_sources,
neighbor_map,
images,
min_radius_pix=ACTIVE_PIXELS_MIN_RADIUS_PX,
)
inferred_params = [results[source_index].vs for source_index in 1:num_sources]
else
for source_index in 1:num_sources
Expand All @@ -337,7 +348,7 @@ function run_benchmarks(; test_case_names=String[], print_fn=println, joint_infe
target_sources,
neighbor_map,
images,
infer_source_callback=infer_source_callback,
infer_source_callback=infer_source_min_radius,
)
@assert length(results) == 1
push!(inferred_params, results[1].vs)
Expand Down
11 changes: 3 additions & 8 deletions src/Infer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,9 @@ function find_neighbors(target_sources::Vector{Int64},
end


"""
noise_fraction: The proportion of the noise below which we will remove pixels.
min_radius_pix: A minimum pixel radius to be included.
"""
function get_sky_patches(images::Vector{Image},
catalog::Vector{CatalogEntry};
radius_override_pix=NaN,
noise_fraction=0.1,
min_radius_pix=8.0)
radius_override_pix=NaN)
N = length(images)
S = length(catalog)
patches = Array(SkyPatch, S, N)
Expand Down Expand Up @@ -141,7 +135,8 @@ function load_active_pixels!(images::Vector{Image},
patches::Matrix{SkyPatch};
exclude_nan=true,
noise_fraction=0.5,
min_radius_pix=8.0)
min_radius_pix=Nullable{Float64}())
min_radius_pix = get(min_radius_pix, 8.0)
S, N = size(patches)

for n = 1:N, s=1:S
Expand Down
5 changes: 3 additions & 2 deletions src/deterministic_vi_image_psf/elbo_image_psf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,11 @@ function initialize_fft_elbo_parameters(
active_sources::Vector{Int};
use_raw_psf=true,
use_trimmed_psf=true,
allocate_fsm_mat=true)
allocate_fsm_mat=true,
min_radius_pix=Nullable{Float64}())

ea = ElboArgs(images, vp, patches, active_sources, psf_K=1)
load_active_pixels!(images, ea.patches; exclude_nan=false)
load_active_pixels!(images, ea.patches; exclude_nan=false, min_radius_pix=min_radius_pix)

fsm_mat = nothing
if allocate_fsm_mat
Expand Down
5 changes: 3 additions & 2 deletions src/joint_infer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ function one_node_joint_infer(catalog, target_sources, neighbor_map, images;
batch_size=60,
within_batch_shuffling=true,
n_iters=3,
use_default_optim_params=true)
use_default_optim_params=true,
min_radius_pix=Nullable{Float64}())
# Seed random number generator to ensure the same results per run.
srand(42)

Expand Down Expand Up @@ -279,7 +280,7 @@ function one_node_joint_infer(catalog, target_sources, neighbor_map, images;
ids_local = vcat([entry_id], neighbor_ids)

patches = Infer.get_sky_patches(images, cat_local)
Infer.load_active_pixels!(images, patches)
Infer.load_active_pixels!(images, patches, min_radius_pix=min_radius_pix)

# Load vp with shared target source params, and also vp
# that doesn't share target source params
Expand Down

0 comments on commit 9948edb

Please sign in to comment.