Skip to content

Commit

Permalink
Merge branch 'tickets/DM-42157'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Mar 14, 2024
2 parents 0560859 + c1bd744 commit 8b78c75
Show file tree
Hide file tree
Showing 29 changed files with 3,758 additions and 1,436 deletions.
272 changes: 167 additions & 105 deletions examples/fithsc.ipynb

Large diffs are not rendered by default.

185 changes: 120 additions & 65 deletions examples/fithsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,38 @@
from astropy.wcs import WCS
import gauss2d as g2
import gauss2d.fit as g2f
from lsst.multiprofit.componentconfig import SersicConfig, SersicIndexConfig
from lsst.multiprofit.fit_psf import CatalogExposurePsfABC, CatalogPsfFitter, CatalogPsfFitterConfig
from lsst.multiprofit.componentconfig import SersicComponentConfig, SersicIndexParameterConfig
from lsst.multiprofit.fit_psf import (
CatalogExposurePsfABC,
CatalogPsfFitter,
CatalogPsfFitterConfig,
CatalogPsfFitterConfigData,
)
from lsst.multiprofit.fit_source import (
CatalogExposureSourcesABC,
CatalogSourceFitterABC,
CatalogSourceFitterConfig,
CatalogSourceFitterConfigData,
)
from lsst.multiprofit.modelconfig import ModelConfig
from lsst.multiprofit.plots import plot_model_rgb
from lsst.multiprofit.utils import ArbitraryAllowedConfig
from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
from lsst.multiprofit.utils import ArbitraryAllowedConfig, get_params_uniq
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pydantic
from pydantic.dataclasses import dataclass


# In[2]:


# Define settings
band_ref = 'i'
bands = {'i': 0.87108833, 'r': 0.97288654, 'g': 1.44564678}
band_multi = ''.join(bands)
channels = {band: g2f.Channel.get(band) for band in bands}

# This is in the WCS, but may as well keep full precision
scale_pixel_hsc = 0.168
Expand Down Expand Up @@ -185,13 +195,18 @@ def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.array:
catalog_psf = apTab.Table({'object_id': [tab_row_main['object_id']]})
results_psf = {}

# Keep a separate configdata_psf per band, because it has a cached PSF model
# those should not be shared!
config_data_psfs = {}
for band, psf_file in psfs.items():
config_data_psf = CatalogPsfFitterConfigData(config=config_psf)
catexp = CatalogExposurePsf(catalog=catalog_psf, img=psf_file[0].data)
t_start = time.time()
result = fitter_psf.fit(config=config_psf, catexp=catexp)
result = fitter_psf.fit(config_data=config_data_psf, catexp=catexp)
t_end = time.time()
results_psf[band] = result
print(f"Fit {band}-band PSF in {t_end - t_start:.2f}s; result:")
config_data_psfs[band] = config_data_psf
print(f"Fit {band}-band PSF in {t_end - t_start:.2e}s; result:")
print(dict(result[0]))


Expand All @@ -200,20 +215,33 @@ def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.array:

# Set fit configs
config_source = CatalogSourceFitterConfig(
column_id = 'object_id',
n_pointsources=1,
sersics={
'disk': SersicConfig(
sersicindex=SersicIndexConfig(value_initial=1., fixed=True),
prior_size_stddev=0.3,
prior_axrat_stddev=0.2,
),
'bulge': SersicConfig(
sersicindex=SersicIndexConfig(value_initial=4., fixed=True),
prior_size_stddev=0.3,
prior_axrat_stddev=0.2,
),
}
column_id='object_id',
config_model=ModelConfig(
sources={
"src": SourceConfig(
component_groups={
"": ComponentGroupConfig(
components_sersic={
'disk': SersicComponentConfig(
sersic_index=SersicIndexParameterConfig(value_initial=1., fixed=True),
prior_size_stddev=0.5,
prior_axrat_stddev=0.2,
),
'bulge': SersicComponentConfig(
sersic_index=SersicIndexParameterConfig(value_initial=4., fixed=True),
prior_size_stddev=0.1,
prior_axrat_stddev=0.2,
),
},
),
}
),
},
),
)
config_data_source = CatalogSourceFitterConfigData(
channels=list(channels.values()),
config=config_source,
)


Expand All @@ -223,7 +251,7 @@ def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.array:
# Setup exposure with band-specific image, mask and variance
@dataclass(frozen=True, config=ArbitraryAllowedConfig)
class CatalogExposureSources(CatalogExposureSourcesABC):
config_psf: CatalogPsfFitterConfig = pydantic.Field(title="The PSF fit config")
config_data_psf: CatalogPsfFitterConfigData = pydantic.Field(title="The PSF fit config")
observation: g2f.Observation = pydantic.Field(title="The observation to fit")
table_psf_fits: apTab.Table = pydantic.Field(title="The table of PSF fit parameters")

Expand All @@ -234,8 +262,9 @@ def channel(self) -> g2f.Channel:
def get_catalog(self) -> Iterable:
return self.table_psf_fits

def get_psfmodel(self, params: Mapping[str, Any]) -> g2f.PsfModel:
return self.config_psf.rebuild_psfmodel(params)
def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:
self.config_data_psf.init_psf_model(params)
return self.config_data_psf.psf_model

def get_source_observation(self, source: Mapping[str, Any]) -> g2f.Observation:
return self.observation
Expand All @@ -251,44 +280,66 @@ def initialize_model(
self,
model: g2f.Model,
source: Mapping[str, Any],
limits_x: g2f.LimitsD = None,
limits_y: g2f.LimitsD = None,
) -> None:
catexps: list[CatalogExposureSourcesABC],
values_init: Mapping[g2f.ParameterD, float] | None = None,
centroid_pixel_offset: float = 0,
**kwargs
):
if values_init is None:
values_init = {}
x, y = source['x'], source['y']
scale_sq = self.scale_pixel**(-2)
ellipse = g2.Ellipse(g2.Covariance(
sigma_x_sq=source[f'{band}_sdssshape_shape11']*scale_sq,
sigma_y_sq=source[f'{band}_sdssshape_shape22']*scale_sq,
cov_xy=source[f'{band}_sdssshape_shape12']*scale_sq,
))
values_init = {
g2f.CentroidXParameterD: x,
g2f.CentroidYParameterD: y,
g2f.ReffXParameterD: ellipse.sigma_x,
g2f.ReffYParameterD: ellipse.sigma_y,
# There is a sign convention difference
g2f.RhoParameterD: -ellipse.rho,
}
size_major = g2.EllipseMajor(ellipse).r_major
limits_size = g2f.LimitsD(1e-5, np.sqrt(x*x + y*y))
# An R_eff larger than the box size is problematic
# Also should stop unreasonable size proposals; log10 transform isn't enough
# TODO: Try logit for r_eff?
limits_init = {
g2f.IntegralParameterD: g2f.LimitsD(1e-6, 1e10),
g2f.ReffXParameterD: g2f.LimitsD(1e-5, x),
g2f.ReffYParameterD: g2f.LimitsD(1e-5, y),
params_limits_init = {
# Should set limits based on image size, but this shortcut is fine
# for this particular object
g2f.CentroidXParameterD: (x, g2f.LimitsD(0, 2*x)),
g2f.CentroidYParameterD: (x, g2f.LimitsD(0, 2*y)),
g2f.ReffXParameterD: (ellipse.sigma_x, limits_size),
g2f.ReffYParameterD: (ellipse.sigma_y, limits_size),
# There is a sign convention difference
g2f.RhoParameterD: (-ellipse.rho, None),
g2f.IntegralParameterD: (1.0, g2f.LimitsD(1e-10, 1e10)),
}
for component in model.sources[0].components:
params_free = component.parameters(paramfilter=g2f.ParamFilter(fixed=False))
for param in params_free:
type_param = type(param)
if (value := values_init.get(type_param)) is not None:
param.value = value
if (limits := limits_init.get(type_param)) is not None:
param.limits = limits
params_free = get_params_uniq(model, fixed=False)
for param in params_free:
type_param = type(param)
value_init, limits_new = params_limits_init.get(
type_param,
(values_init.get(param), None)
)
if value_init is not None:
param.value = value_init
if limits_new:
# For slightly arcane reasons, we must set a new limits object
# Changing limits values is unreliable
param.limits = limits_new
for prior in model.priors:
if isinstance(prior, g2f.ShapePrior):
prior.prior_size.mean_parameter.value = g2.EllipseMajor(ellipse).r_major
prior.prior_size.mean_parameter.value = size_major


def validate_fit_inputs(
self,
catalog_multi,
catexps: list[CatalogExposureSourcesABC],
config_data: CatalogSourceFitterConfigData = None,
logger = None,
**kwargs: Any,
) -> None:
super().validate_fit_inputs(
catalog_multi=catalog_multi, catexps=catexps, config_data=config_data,
logger=logger, **kwargs
)


# In[9]:
Expand Down Expand Up @@ -321,7 +372,7 @@ def initialize_model(
)
observations[band] = observation
catexps[band] = CatalogExposureSources(
config_psf=config_psf,
config_data_psf=config_data_psfs[band],
observation=observation,
table_psf_fits=results_psf[band],
)
Expand All @@ -335,10 +386,10 @@ def initialize_model(
result_multi = fitter.fit(
catalog_multi=tab_row_main,
catexps=list(catexps.values()),
config=config_source,
config_data=config_data_source,
)
t_end = time.time()
print(f"Fit {','.join(bands.keys())}-band bulge-disk model in {t_end - t_start:.2f}s; result:")
print(f"Fit {','.join(bands.keys())}-band bulge-disk model in {t_end - t_start:.2e}s; result:")
print(dict(result_multi[0]))


Expand All @@ -348,11 +399,15 @@ def initialize_model(
# Fit in each band separately
results = {}
for band, observation in bands.items():
config_data_source_band = CatalogSourceFitterConfigData(
channels=[channels[band]],
config=config_source,
)
t_start = time.time()
result = fitter.fit(
catalog_multi=tab_row_main,
catexps=[catexps[band]],
config=config_source,
config_data=config_data_source_band,
)
t_end = time.time()
results[band] = result
Expand All @@ -364,8 +419,9 @@ def initialize_model(


# Make a model for the best-fit params
model, data, *_ = config_source.make_model_data(idx_source=0, catexps=list(catexps.values()))
params = list({p: None for p in model.parameters(paramfilter=g2f.ParamFilter(fixed=False))}.keys())
data, psf_models = config_source.make_model_data(idx_row=0, catexps=list(catexps.values()))
model = g2f.Model(data=data, psfmodels=psf_models, sources=config_data_source.sources_priors[0], priors=config_data_source.sources_priors[1])
params = get_params_uniq(model, fixed=False)
result_multi_row = dict(result_multi[0])
# This is the last column before fit params
idx_last = next(idx for idx, column in enumerate(result_multi_row.keys()) if column == 'mpf_unknown_flag')
Expand All @@ -379,17 +435,17 @@ def initialize_model(


# ### Multiband Residuals
#
#
# What's with the structure in the residuals? Most broadly, a point source + exponential disk + deVauc bulge model is totally inadequate for this galaxy for several possible reasons:
#
#
# 1. The disk isn't exactly exponential (n=1)
# 2. The disk has colour gradients not accounted for in this model*
# 3. If the galaxy even has a bulge, it's very weak and def. not a deVaucouleurs (n=4) profile; it may be an exponential "pseudobulge"
#
#
# \*MultiProFit can do more general Gaussian mixture models (linear or non-linear), which may be explored in a future iteration of this notebook, but these are generally do not improve the accuracy of photometry for smaller/fainter galaxies.
#
#
# Note that the two scalings of the residual plots (98%ile and +/- 20 sigma) end up looking very similar.
#
#

# In[13]:

Expand All @@ -408,27 +464,26 @@ def initialize_model(
packed = np.packbits(mask_inv_highsn, bitorder='little')
np.savez_compressed(f'{prefix_img}mask_inv_highsn.npz', mask_inv=mask_highsn)

# TODO: Plotting functions will be refactored from old MPF
# Missing features include colour residual images
# Also complete labels, etc.
# TODO: Some features still missing from plot_model_rgb
# residual histograms, param values, better labels, etc


# ### More exercises for the reader
#
#
# These are of the sort that the author hasn't gotten around to yet because they're far from trivial. Try:
#
#
# 0. Use the WCS to compute ra, dec and errors thereof.
# Hint: override CatalogSourceFitter.get_model_radec
#
#
# 1. Replace the real data with simulated data.
# Make new observations using model.evaluate and add noise based on the variance maps.
# Try fitting again and see how well results converge depending on the initialization scheme.
#
#
# 2. Fit every other source individually.
# Try subtracting the best-fit galaxy model from above first.
# Hint: get_source_observation should be redefined to return a smaller postage stamp around the nominal centroid.
# Pass the full catalog (excluding the central galaxy) to catalog_multi.
#
#
# 3. Fit all sources simultaneously.
# Redefine CatalogFitterConfig.make_model_data to make a model with multiple sources, using the catexp catalogs
# initialize_model will no longer need to do anything
Expand Down
4 changes: 2 additions & 2 deletions examples/plot_sersic_mix.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gauss2d.fit as g2f
from lsst.multiprofit.plots import plot_sersicmix_interp
import matplotlib.pyplot as plt
import numpy as np
from lsst.multiprofit.plots import plot_sersicmix_interp
from scipy.interpolate import CubicSpline

interps = {
Expand All @@ -10,7 +10,7 @@
"scipy-csp": ((CubicSpline, {}), (0, (4, 4))),
}

for n_low, n_hi in ((0.5, 0.7), (2.2, 4.4)):
for n_low, n_hi in ((0.5, 0.7), (0.8, 1.2), (2.2, 4.4)):
n_ser = 10 ** np.linspace(np.log10(n_low), np.log10(n_hi), 100)
plot_sersicmix_interp(interps=interps, n_ser=n_ser, figsize=(10, 8))
plt.tight_layout()
Expand Down
Loading

0 comments on commit 8b78c75

Please sign in to comment.