Add transit masking and polynomial normalization method
Introduced a `mask_transit` method to define the out-of-transit mask from ephemeris data, and added a `normalize_to_poly` method for polynomial normalization of the baseline fluxes. Updated the plotting functions to optionally show the transit limits, and deprecated the older normalization methods in favor of the new ones.
hpparvi committed Oct 24, 2024
1 parent 268c8c1 commit 51224be
Showing 2 changed files with 102 additions and 22 deletions.
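For illustration only, not part of the commit: a minimal sketch of how the `mask_transit` and `normalize_to_poly` methods introduced below could be used together. The toy arrays, the ephemeris values, and the `exoiris.tsdata` import path are assumptions based on the file layout and the signatures shown in this diff.

import numpy as np
from exoiris.tsdata import TSData

# Toy data: 20 spectroscopic channels times 500 exposures (all values made up).
time = np.linspace(0.0, 0.3, 500)
wavelength = np.linspace(1.0, 2.0, 20)
fluxes = np.ones((20, 500)) + np.random.normal(0.0, 1e-3, (20, 500))
errors = np.full((20, 500), 1e-3)

d = TSData(time, wavelength, fluxes, errors, name='example')

# Ephemeris-based masking: zero epoch, period, and transit duration in the same units as `time`.
d.mask_transit(t0=0.15, p=3.0, t14=0.1)

# Fit a linear (deg=1) polynomial to the out-of-transit points of each channel and divide it out.
d.normalize_to_poly(deg=1)

# With an ephemeris attached, the plot also marks the transit limits.
fig = d.plot()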
7 changes: 1 addition & 6 deletions exoiris/exoiris.py
@@ -492,13 +492,8 @@ def normalize_baseline(self, deg: int = 1) -> None:
`deg` to the out-of-transit data points and divides the fluxes by the fitted polynomial evaluated
at each time point.
"""
if deg > 1:
raise ValueError("The degree of the fitted polynomial ('deg') should be 0 or 1. Higher degrees are not allowed because they could affect the transit depths.")
for d in self.data:
for ipb in range(d.nwl):
pl = poly1d(polyfit(d.time[d.ootmask], d.fluxes[ipb, d.ootmask], deg=deg))(d.time)
d.fluxes[ipb, :] /= pl
d.errors[ipb, :] /= pl
d.normalize_to_poly(deg)

def plot_baseline(self, axs: Optional[Sequence[Axes]] = None, figsize=None) -> Figure:
"""Plot the out-of-transit spectroscopic light curves before and after the normalization.
117 changes: 101 additions & 16 deletions exoiris/tsdata.py
@@ -24,16 +24,18 @@

from astropy.io import fits as pf
from astropy.stats import mad_std
from astropy.utils import deprecated
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import subplots, setp
from matplotlib.ticker import LinearLocator, FuncFormatter
from numpy import isfinite, median, where, concatenate, all, zeros_like, diff, asarray, interp, arange, floor, ndarray, \
ceil, newaxis, inf, array, ones, unique
ceil, newaxis, inf, array, ones, unique, poly1d, polyfit
from numpy.ma.extras import atleast_2d
from pytransit.orbits import fold
from scipy.ndimage import median_filter

from .ephemeris import Ephemeris
from .util import bin2d
from .binning import Binning, CompoundBinning

@@ -57,7 +59,7 @@ class TSData:
"""
def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, errors: Sequence, name: str,
noise_group: str = 'a', wl_edges : Sequence | None = None, tm_edges : Sequence | None = None,
ootmask: ndarray | None = None) -> None:
ootmask: ndarray | None = None, ephemeris: Ephemeris | None = None) -> None:
"""
Parameters
----------
@@ -96,6 +98,7 @@ def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, error
self.errors: ndarray = errors[m]
self.ootmask: ndarray = ootmask if ootmask is not None else ones(time.size, dtype=bool)
self.ngid: int = 0
self.ephemeris: Ephemeris | None = ephemeris
self._noise_group: str = noise_group
self._dataset: 'TSDataSet' | None = None
self._update()
@@ -126,6 +129,7 @@ def export_fits(self) -> pf.HDUList:
data = pf.ImageHDU(array([self.fluxes, self.errors]), name=f'data_{self.name}')
ootm = pf.ImageHDU(self.ootmask.astype(int), name=f'ootm_{self.name}')
data.header['ngroup'] = self.noise_group
#TODO: export ephemeris
return pf.HDUList([time, wave, data, ootm])

@staticmethod
@@ -135,6 +139,7 @@ def import_fits(name: str, hdul: pf.HDUList) -> 'TSData':
data = hdul[f'DATA_{name}'].data
ootm = hdul[f'OOTM_{name}'].data
noise_group = hdul[f'DATA_{name}'].header['NGROUP']
#TODO: import ephemeris
return TSData(time, wave, data[0], data[1], name=name, noise_group=noise_group, ootmask=ootm)

def __repr__(self) -> str:
@@ -150,6 +155,18 @@ def noise_group(self, ng: str) -> None:
if self._dataset is not None:
self._dataset._update_nids()

def mask_transit(self, t0: float | None = None, p: float | None = None, t14: float | None = None,
elims: tuple[int, int] | None = None) -> None:
if t0 and p and t14:
self.ephemeris = Ephemeris(t0, p, t14)
phase = fold(self.time, p, t0)
self.ootmask = abs(phase) > 0.502 * t14
elif elims is not None:
self.ootmask = ones(self.fluxes.shape, bool)
self.ootmask[:, *elims] = False
else:
raise ValueError("Transit masking requires either t0, pp, and t14, or transit limits in exposure indices.")

def calculate_ootmask(self, t0: float, p: float, t14: float):
phase = fold(self.time, p, t0)
self.ootmask = abs(phase) > 0.502 * t14
@@ -160,7 +177,58 @@ def _update(self) -> None:
self.npt = self.time.size
self.wllims = self.wavelength.min(), self.wavelength.max()

def normalize(self, s: slice) -> None:
@deprecated(0.9, alternative='normalize_to_poly')
def normalize_baseline(self, deg: int = 1) -> None:
return self.normalize_to_poly(deg)

def normalize_to_poly(self, deg: int = 1) -> None:
"""Normalize the baseline flux for each spectroscopic light curve.
Normalize the baseline flux using a low-order polynomial fitted to the out-of-transit
data for each spectroscopic light curve.
Parameters
----------
deg
The degree of the fitted polynomial. Should be 0 or 1. Higher degrees are not allowed
because they could affect the transit depths.
Raises
------
ValueError
If `deg` is greater than 1.
Notes
-----
This method normalizes the baseline of the fluxes for each planet. It fits a polynomial of degree
`deg` to the out-of-transit data points and divides the fluxes by the fitted polynomial evaluated
at each time point.
"""
if deg > 1:
raise ValueError("The degree of the fitted polynomial ('deg') should be 0 or 1. Higher degrees "
"are not allowed because they could affect the transit depths.")

if self.ootmask is None:
raise ValueError("The out-of-transit mask must be defined for normalization. "
"Call TSData.mask_transit(...) first.")

for ipb in range(self.nwl):
bl = poly1d(polyfit(self.time[self.ootmask], self.fluxes[ipb, self.ootmask], deg=deg))(self.time)
self.fluxes[ipb, :] /= bl
self.errors[ipb, :] /= bl

@deprecated(0.9, alternative='normalize_to_median')
def normalize_median(self, s: slice) -> None:
"""Normalize the light curves to the median flux of the given slice along the time axis.
Parameters
----------
s : slice
A slice object representing the portion of the data to normalize.
"""
self.normalize_to_median(s)

def normalize_to_median(self, s: slice) -> None:
"""Normalize the light curves to the median flux of the given slice along the time axis.
Parameters
@@ -185,9 +253,9 @@ def split_time(self, t: float, b: float) -> 'TSDataSet':
m1 = self.time < t-b
m2 = self.time > t+b
t1 = TSData(time=self.time[m1], wavelength=self.wavelength, fluxes=self.fluxes[:, m1], errors=self.errors[:, m1],
noise_group=self.noise_group)
noise_group=self.noise_group, ootmask=self.ootmask, ephemeris=self.ephemeris)
t2 = TSData(time=self.time[m2], wavelength=self.wavelength, fluxes=self.fluxes[:, m2], errors=self.errors[:, m2],
noise_group=self.noise_group)
noise_group=self.noise_group, ootmask=self.ootmask, ephemeris=self.ephemeris)
return t1 + t2

def partition_time(self, tlims: tuple[tuple[float,float]]) -> 'TSDataSet':
@@ -201,11 +269,12 @@ def partition_time(self, tlims: tuple[tuple[float,float]]) -> 'TSDataSet':
masks = [(self.time >= l[0]) & (self.time <= l[1]) for l in tlims]
m = masks[0]
d = TSData(time=self.time[m], wavelength=self.wavelength, fluxes=self.fluxes[:, m], errors=self.errors[:, m],
name=f'{self.name}_1', noise_group=self.noise_group)
name=f'{self.name}_1', noise_group=self.noise_group, ootmask=self.ootmask, ephemeris=self.ephemeris)
for i, m in enumerate(masks[1:]):
d = d + TSData(time=self.time[m], wavelength=self.wavelength,
fluxes=self.fluxes[:, m], errors=self.errors[:, m],
name=f'{self.name}_{i+2}', noise_group=self.noise_group)
name=f'{self.name}_{i+2}', noise_group=self.noise_group,
ootmask=self.ootmask, ephemeris=self.ephemeris)
return d

def crop_wavelength(self, lmin: float, lmax: float) -> None:
@@ -304,29 +373,43 @@ def plot(self, ax=None, vmin: float = None, vmax: float = None, cmap=None, figsi
"""
if ax is None:
fig, ax = subplots(figsize=figsize)
fig, ax = subplots(figsize=figsize, constrained_layout=True)
else:
fig = ax.figure
tref = floor(self.time.min())

def forward(x):
def forward_x(x):
return interp(x, self.wavelength, arange(self.nwl))
def inverse(x):
def inverse_x(x):
return interp(x, arange(self.nwl), self.wavelength)
def forward_y(y):
return interp(y, self.time-tref, arange(self.npt))
def inverse_y(y):
return interp(y, arange(self.npt), self.time-tref)

data = data if data is not None else self.fluxes
ax.pcolormesh(self.time - tref, self.wavelength, data, vmin=vmin, vmax=vmax, cmap=cmap)

if self.ephemeris is not None:
[ax.axvline(tl-tref, ls='--', c='k') for tl in self.ephemeris.transit_limits(self.time.mean())]

setp(ax, ylabel=r'Wavelength [$\mu$m]', xlabel=f'Time - {tref:.0f} [BJD]')
ax.yaxis.set_major_locator(LinearLocator(10))
ax.yaxis.set_major_formatter('{x:.2f}')

if self.name != "":
ax.set_title(self.name)

ax2 = ax.secondary_yaxis('right', functions=(forward, inverse))
ax2.set_ylabel('Light curve index')
ax2.set_yticks(forward(ax.get_yticks()))
ax2.yaxis.set_major_formatter('{x:.0f}')
axx2 = ax.secondary_yaxis('right', functions=(forward_x, inverse_x))
axx2.set_ylabel('Light curve index')
axx2.set_yticks(forward_x(ax.get_yticks()))
axx2.yaxis.set_major_formatter('{x:.0f}')
axy2 = ax.secondary_xaxis('top', functions=(forward_y, inverse_y))
axy2.set_xlabel('Exposure index')
axy2.set_xticks(forward_y(ax.get_xticks()))
axy2.xaxis.set_major_formatter('{x:.0f}')
fig.axx2 = axx2
fig.axy2 = axy2
return fig

def plot_white(self, ax=None, figsize=None) -> Axes:
@@ -395,7 +478,8 @@ def bin_wavelength(self, binning: Optional[Union[Binning, CompoundBinning]] = No
if not all(isfinite(be)):
warnings.warn('Error estimation failed for some bins, check the error array.')
return TSData(self.time, binning.bins.mean(1), bf, be, wl_edges=(binning.bins[:,0], binning.bins[:,1]),
name=self.name, tm_edges=(self._tm_l_edges, self._tm_r_edges), noise_group=self.noise_group)
name=self.name, tm_edges=(self._tm_l_edges, self._tm_r_edges), noise_group=self.noise_group,
ootmask=self.ootmask, ephemeris=self.ephemeris)


def bin_time(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
@@ -433,7 +517,8 @@ def bin_time(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
return TSData(binning.bins.mean(1), self.wavelength, bf.T, be.T,
wl_edges=(self._wl_l_edges, self._wl_r_edges),
tm_edges=(binning.bins[:,0], binning.bins[:,1]),
name=self.name, noise_group=self.noise_group)
name=self.name, noise_group=self.noise_group,
ootmask=self.ootmask, ephemeris=self.ephemeris)


class TSDataSet:
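Also for illustration, continuing with the `TSData` instance `d` from the sketch above: the deprecated wrappers keep the old entry points working while steering callers to the new names. The warning class shown is what astropy's `@deprecated` decorator normally emits; treat it as an assumption rather than something verified against this commit.

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    d.normalize_baseline(deg=1)          # still works, but delegates to normalize_to_poly(1)
    print(caught[-1].category.__name__)  # expected: AstropyDeprecationWarning

# Preferred, equivalent call going forward:
d.normalize_to_poly(deg=1)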
