Skip to content

Commit

Permalink
add Config object to faciliate passing params deep into the call chain
Browse files Browse the repository at this point in the history
As discucsed in PR jeff-regier#538, we have this problem of passing arguments deep down the call stack of
Celeste inference. This manifests itself as lots of chains of keyword arguments and various
not-so-pretty hacks (like custom `infer_source_callback` wrappers) to set options.

The discussion in jeff-regier#538 outlines a broad direction for not only solving this problem, but also the
related problem of different branches of code for different options at each stage of inference. The
two current examples are single vs. joint inference and MOG vs. FFT PSF inference, but there will
likely always be such branches. I still think that's a good direction to go, but even a minimal
implementation is too invasive to be safe/worthwhile right now.

Instead, this commit takes what I view as the simplest possible step in the right direction. We add
a single `Config` object which is instantiated at top-level and passed as the first argument through
all major function calls. Over time this object ought to absorb more parameters. Once there are
distinct groups of parameters, the object can be split into multiple objects, and functions can
receive only the sub-Config-object they require. Dynamic dispatch of different types of
sub-Config-objects can follow from that. (If this doesn't make sense, read the example in jeff-regier#538.)

The change looks a little silly right now, since it's this config object with one parameter, but
it's an incremental path. There are a lot of legacy wrappers included which instantiate the default
Config object, so I don't have to go change every top-level caller (in which case I'd likely break
something because I'm not familiar with much of the code), but over time callers ought to be changed
to explicitly pass a Config object, and the legacy wrappers can go away.
  • Loading branch information
gostevehoward committed Mar 3, 2017
1 parent 15a3605 commit 31b704a
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 53 deletions.
2 changes: 2 additions & 0 deletions benchmark/accuracy/run_celeste_on_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ArgParse
using DataFrames

import Celeste.AccuracyBenchmark
import Celeste.Configs
import Celeste.Infer
import Celeste.ParallelRun
import Celeste.SDSSIO
Expand Down Expand Up @@ -60,6 +61,7 @@ end
neighbor_map = Infer.find_neighbors(target_sources, catalog_entries, images)

results = AccuracyBenchmark.run_celeste(
Configs.Config(),
catalog_entries,
target_sources,
images,
Expand Down
6 changes: 5 additions & 1 deletion src/AccuracyBenchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import FITSIO
import StaticArrays
import WCS

import Celeste.Configs
import Celeste.DeterministicVI
import Celeste.DeterministicVIImagePSF
import Celeste.Infer
Expand Down Expand Up @@ -743,11 +744,13 @@ end

# Run Celeste with any combination of single/joint inference and MOG/FFT model
function run_celeste(
catalog_entries, target_sources, images; use_joint_inference=false, use_fft=false
config::Configs.Config, catalog_entries, target_sources, images;
use_joint_inference=false, use_fft=false
)
neighbor_map = Infer.find_neighbors(target_sources, catalog_entries, images)
if use_joint_inference
ParallelRun.one_node_joint_infer(
config,
catalog_entries,
target_sources,
neighbor_map,
Expand All @@ -761,6 +764,7 @@ function run_celeste(
infer_source_callback = DeterministicVI.infer_source
end
ParallelRun.one_node_single_infer(
config,
catalog_entries,
target_sources,
neighbor_map,
Expand Down
1 change: 1 addition & 0 deletions src/Celeste.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __precompile__()
module Celeste

# submodules
include("Configs.jl")
include("Log.jl")

include("SensitiveFloats.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/Configs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module Configs

type Config
# A minimum pixel radius to be included around each source.
min_radius_pix::Float64

function Config()
config = new()
config.min_radius_pix = 8.0
config
end
end

end
17 changes: 13 additions & 4 deletions src/DeterministicVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Calculate value, gradient, and hessian of the variational ELBO.
module DeterministicVI

using Base.Threads: threadid, nthreads

import ..Configs
using ..Model
import ..Model: BivariateNormalDerivatives, BvnComponent, GalaxyCacheComponent,
GalaxySigmaDerivs, SkyPatch,
Expand Down Expand Up @@ -109,10 +111,10 @@ Arguments:
neighbors: the other light sources near `entry`
entry: the source to infer
"""
function infer_source(images::Vector{Image},
function infer_source(config::Configs.Config,
images::Vector{Image},
neighbors::Vector{CatalogEntry},
entry::CatalogEntry;
min_radius_pix=Nullable{Float64}())
entry::CatalogEntry)
if length(neighbors) > 100
msg = string("objid $(entry.objid) [ra: $(entry.pos)] has an excessive",
"number ($(length(neighbors))) of neighbors")
Expand All @@ -125,11 +127,18 @@ function infer_source(images::Vector{Image},
cat_local = vcat([entry], neighbors)
vp = init_sources([1], cat_local)
patches = Infer.get_sky_patches(images, cat_local)
Infer.load_active_pixels!(images, patches, min_radius_pix=min_radius_pix)
Infer.load_active_pixels!(config, images, patches)

ea = ElboArgs(images, vp, patches, [1])
f_evals, max_f, max_x, nm_result = NewtonMaximize.maximize!(elbo, ea)
return vp[1]
end

# legacy wrapper
function infer_source(images::Vector{Image},
neighbors::Vector{CatalogEntry},
entry::CatalogEntry)
infer_source(Configs.Config(), images, neighbors, entry)
end

end
16 changes: 12 additions & 4 deletions src/DeterministicVIImagePSF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ module DeterministicVIImagePSF

using StaticArrays, DiffBase, Compat

import ..Configs

import ..DeterministicVI:
ElboArgs, ElboIntermediateVariables,
StarPosParams, GalaxyPosParams, CanonicalParams, VariationalParams,
Expand Down Expand Up @@ -46,14 +48,14 @@ export elbo_likelihood_with_fft!, FSMSensitiveFloatMatrices,
FFTElboFunction, load_fsm_mat


function infer_source_fft(images::Vector{Image},
function infer_source_fft(config::Configs.Config,
images::Vector{Image},
neighbors::Vector{CatalogEntry},
entry::CatalogEntry;
min_radius_pix=Nullable{Float64}())
entry::CatalogEntry)
cat_local = vcat([entry], neighbors)
vp = init_sources([1], cat_local)
patches = get_sky_patches(images, cat_local)
load_active_pixels!(images, patches, min_radius_pix=min_radius_pix)
load_active_pixels!(config, images, patches)

ea_fft, fsm_mat = initialize_fft_elbo_parameters(images, vp, patches, [1], use_raw_psf=true)
elbo_fft_opt = FFTElboFunction(fsm_mat)
Expand All @@ -63,6 +65,12 @@ function infer_source_fft(images::Vector{Image},
return vp[1]
end

# legacy wrapper
function infer_source_fft(images::Vector{Image},
neighbors::Vector{CatalogEntry},
entry::CatalogEntry)
infer_source_fft(Configs.Config(), images, neighbors, entry)
end

function infer_source_fft_two_step(images::Vector{Image},
neighbors::Vector{CatalogEntry},
Expand Down
13 changes: 7 additions & 6 deletions src/GalsimBenchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ module GalsimBenchmark
using DataFrames
import FITSIO

import Celeste: AccuracyBenchmark, Infer
import Celeste.AccuracyBenchmark
import Celeste.Configs
import Celeste.Infer

const GALSIM_BENCHMARK_DIR = joinpath(Pkg.dir("Celeste"), "benchmark", "galsim")
const LATEST_FITS_FILENAME_DIR = joinpath(GALSIM_BENCHMARK_DIR, "latest_filenames")
const ACTIVE_PIXELS_MIN_RADIUS_PX = Nullable(40.0)
const ACTIVE_PIXELS_MIN_RADIUS_PX = 40.0

function get_latest_fits_filename(label)
latest_fits_filename_holder = joinpath(
Expand Down Expand Up @@ -66,10 +68,6 @@ end

function run_benchmarks(; test_case_names=String[], print_fn=println, joint_inference=false,
use_fft=false)
function infer_source_min_radius(args...; kwargs...)
infer_source_callback(args...; min_radius_pix=ACTIVE_PIXELS_MIN_RADIUS_PX, kwargs...)
end

latest_fits_filename = get_latest_fits_filename("galsim_benchmarks")
full_fits_path = joinpath(GALSIM_BENCHMARK_DIR, "output", latest_fits_filename)
extensions = AccuracyBenchmark.read_fits(full_fits_path)
Expand All @@ -89,7 +87,10 @@ function run_benchmarks(; test_case_names=String[], print_fn=println, joint_infe
truth_catalog_df = extract_catalog_from_header(header)
catalog_entries = AccuracyBenchmark.make_initialization_catalog(truth_catalog_df, false)
target_sources = collect(1:num_sources)
config = Configs.Config()
config.min_radius_pix = ACTIVE_PIXELS_MIN_RADIUS_PX
results = AccuracyBenchmark.run_celeste(
config,
catalog_entries,
target_sources,
images,
Expand Down
25 changes: 19 additions & 6 deletions src/Infer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module Infer
import WCS
using StaticArrays

import ..Configs
using ..Model
import ..Log

Expand Down Expand Up @@ -127,14 +128,13 @@ objective function.
Non-standard arguments:
noise_fraction: The proportion of the noise below which we will remove pixels.
min_radius_pix: A minimum pixel radius to be included.
"""
function load_active_pixels!(images::Vector{Image},
function load_active_pixels!(config::Configs.Config,
images::Vector{Image},
patches::Matrix{SkyPatch};
exclude_nan=true,
noise_fraction=0.5,
min_radius_pix=Nullable{Float64}())
min_radius_pix = get(min_radius_pix, 8.0)
noise_fraction=0.5)
@show config.min_radius_pix
S, N = size(patches)

for n = 1:N, s=1:S
Expand All @@ -155,7 +155,7 @@ function load_active_pixels!(images::Vector{Image},

# include pixels that are close, even if they aren't bright
sq_dist = (h - p.pixel_center[1])^2 + (w - p.pixel_center[2])^2
if sq_dist < min_radius_pix^2
if sq_dist < config.min_radius_pix^2
p.active_pixel_bitmap[h2, w2] = true
continue
end
Expand All @@ -173,6 +173,19 @@ function load_active_pixels!(images::Vector{Image},
end
end

# legacy wrapper
function load_active_pixels!(images::Vector{Image},
patches::Matrix{SkyPatch};
exclude_nan=true,
noise_fraction=0.5)
load_active_pixels!(
Configs.Config(),
images,
patches,
exclude_nan=exclude_nan,
noise_fraction=noise_fraction,
)
end

# The range of image pixels in a vector of patches
function get_active_pixel_range(
Expand Down
28 changes: 24 additions & 4 deletions src/ParallelRun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Base.Threads
import FITSIO
import JLD

import ..Configs
import ..Log
using ..Model
import ..SDSSIO
Expand Down Expand Up @@ -186,7 +187,8 @@ end
"""
Optimize the `ts`th element of `sources`.
"""
function process_source(ts::Int,
function process_source(config::Configs.Config,
ts::Int,
target_sources::Vector{Int},
catalog::Vector{CatalogEntry},
neighbor_map::Vector{Vector{Int}},
Expand All @@ -198,7 +200,7 @@ function process_source(ts::Int,
neighbors = catalog[neighbor_map[ts]]

tic()
vs_opt = infer_source_callback(images, neighbors, entry)
vs_opt = infer_source_callback(config, images, neighbors, entry)
ntputs(nodeid, threadid(), "processed objid $(entry.objid) in $(toq()) secs")
return OptimizedSource(entry.thing_id,
entry.objid,
Expand All @@ -212,7 +214,8 @@ end
Use multiple threads to process each target source with the specified
callback and write the results to a file.
"""
function one_node_single_infer(catalog::Vector{CatalogEntry},
function one_node_single_infer(config::Configs.Config,
catalog::Vector{CatalogEntry},
target_sources::Vector{Int},
neighbor_map::Vector{Vector{Int}},
images::Vector{Image};
Expand All @@ -238,7 +241,7 @@ function one_node_single_infer(catalog::Vector{CatalogEntry},
end

try
result = process_source(ts, target_sources, catalog, neighbor_map,
result = process_source(config, ts, target_sources, catalog, neighbor_map,
images;
infer_source_callback=infer_source_callback)

Expand Down Expand Up @@ -268,6 +271,23 @@ function one_node_single_infer(catalog::Vector{CatalogEntry},
return results
end

# legacy wrapper
function one_node_single_infer(catalog::Vector{CatalogEntry},
target_sources::Vector{Int},
neighbor_map::Vector{Vector{Int}},
images::Vector{Image};
infer_source_callback=infer_source,
timing=InferTiming())
one_node_single_infer(
Configs.Config(),
catalog,
target_sources,
neighbor_map,
images,
infer_source_callback=infer_source_callback,
timing=timing,
)
end

"""
Use mulitple threads on one node to fit the Celeste model to sources in a given
Expand Down
26 changes: 23 additions & 3 deletions src/deterministic_vi_image_psf/elbo_image_psf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,17 +391,17 @@ end


function initialize_fft_elbo_parameters(
config::Configs.Config,
images::Vector{Image},
vp::VariationalParams{Float64},
patches::Matrix{SkyPatch},
active_sources::Vector{Int};
use_raw_psf=true,
use_trimmed_psf=true,
allocate_fsm_mat=true,
min_radius_pix=Nullable{Float64}())

)
ea = ElboArgs(images, vp, patches, active_sources, psf_K=1)
load_active_pixels!(images, ea.patches; exclude_nan=false, min_radius_pix=min_radius_pix)
load_active_pixels!(config, images, ea.patches; exclude_nan=false)

fsm_mat = nothing
if allocate_fsm_mat
Expand All @@ -411,6 +411,26 @@ function initialize_fft_elbo_parameters(
ea, fsm_mat
end

# legacy wrapper
function initialize_fft_elbo_parameters(
images::Vector{Image},
vp::VariationalParams{Float64},
patches::Matrix{SkyPatch},
active_sources::Vector{Int};
use_raw_psf=true,
use_trimmed_psf=true,
allocate_fsm_mat=true,
)
initialize_fft_elbo_parameters(
images,
vp,
patches,
active_sources,
use_raw_psf=use_raw_psf,
use_trimmed_psf=use_trimmed_psf,
allocate_fsm_mat=allocate_fsm_mat,
)
end

@doc """
Return a function callback for an FFT elbo.
Expand Down
Loading

0 comments on commit 31b704a

Please sign in to comment.