Skip to content

Commit

Permalink
JP-3669: Updating the C Extension to do CHARGELOSS Read Noise Recalcu…
Browse files Browse the repository at this point in the history
…lations (#275)
  • Loading branch information
tapastro authored Sep 10, 2024
2 parents c23ac15 + d85f859 commit b65f560
Show file tree
Hide file tree
Showing 10 changed files with 508 additions and 74 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ __pycache__/
*$py.class
*~

# Temp files
*.*.swp

# C extensions
*.so

Expand Down
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ General

- Add TweakReg submodule. [#267]

ramp_fitting
~~~~~~~~~~~~

- Move the CHARGELOSS read noise variance recalculation from the JWST step
code to the C extension to simplify the code and improve performance.[#275]

Changes to API
--------------

Expand Down
2 changes: 2 additions & 0 deletions changes/275.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[ramp_fitting] Moving the read noise recalculation due to CHARGELOSS flagging from
the JWST ramp fit step code into the STCAL ramp fit C-extension.
7 changes: 7 additions & 0 deletions src/stcal/ramp_fitting/ols_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,7 @@ def discard_miri_groups(ramp_data):
data = ramp_data.data
err = ramp_data.err
groupdq = ramp_data.groupdq
orig_gdq = ramp_data.orig_gdq

n_int, ngroups, nrows, ncols = data.shape

Expand Down Expand Up @@ -949,6 +950,8 @@ def discard_miri_groups(ramp_data):
if num_bad_slices > 0:
data = data[:, num_bad_slices:, :, :]
err = err[:, num_bad_slices:, :, :]
if orig_gdq is not None:
orig_gdq = orig_gdq[:, num_bad_slices:, :, :]

log.info("Number of leading groups that are flagged as DO_NOT_USE: %s", num_bad_slices)

Expand All @@ -968,6 +971,8 @@ def discard_miri_groups(ramp_data):
data = data[:, :-1, :, :]
err = err[:, :-1, :, :]
groupdq = groupdq[:, :-1, :, :]
if orig_gdq is not None:
orig_gdq = orig_gdq[:, :-1, :, :]

log.info("MIRI dataset has all pixels in the final group flagged as DO_NOT_USE.")

Expand All @@ -981,6 +986,8 @@ def discard_miri_groups(ramp_data):
ramp_data.data = data
ramp_data.err = err
ramp_data.groupdq = groupdq
if orig_gdq is not None:
ramp_data.orig_gdq = orig_gdq

return True

Expand Down
20 changes: 17 additions & 3 deletions src/stcal/ramp_fitting/ramp_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
BUFSIZE = 1024 * 300000 # 300Mb cache size for data section


def create_ramp_fit_class(model, dqflags=None, suppress_one_group=False):
def create_ramp_fit_class(model, algorithm, dqflags=None, suppress_one_group=False):
"""
Create an internal ramp fit class from a data model.
Expand Down Expand Up @@ -58,11 +58,24 @@ def create_ramp_fit_class(model, dqflags=None, suppress_one_group=False):
else:
dark_current_array = model.average_dark_current

orig_gdq = None
if algorithm.upper() == "OLS_C":
wh_chargeloss = np.where(np.bitwise_and(model.groupdq.astype(np.uint32), dqflags['CHARGELOSS']))
if len(wh_chargeloss[0]) > 0:
orig_gdq = model.groupdq.copy()
del wh_chargeloss

if isinstance(model.data, u.Quantity):
ramp_data.set_arrays(model.data.value, model.err.value, model.groupdq,
model.pixeldq, dark_current_array)
else:
ramp_data.set_arrays(model.data, model.err, model.groupdq, model.pixeldq, dark_current_array)
ramp_data.set_arrays(
model.data,
model.err,
model.groupdq,
model.pixeldq,
dark_current_array,
orig_gdq)

# Attribute may not be supported by all pipelines. Default is NoneType.
drop_frames1 = model.meta.exposure.drop_frames1 if hasattr(model, "drop_frames1") else None
Expand All @@ -78,6 +91,7 @@ def create_ramp_fit_class(model, dqflags=None, suppress_one_group=False):
if "zero_frame" in model.meta.exposure and model.meta.exposure.zero_frame:
ramp_data.zeroframe = model.zeroframe

ramp_data.algorithm = algorithm
ramp_data.set_dqflags(dqflags)
ramp_data.start_row = 0
ramp_data.num_rows = ramp_data.data.shape[2]
Expand Down Expand Up @@ -170,7 +184,7 @@ def ramp_fit(
# Create an instance of the internal ramp class, using only values needed
# for ramp fitting from the to remove further ramp fitting dependence on
# data models.
ramp_data = create_ramp_fit_class(model, dqflags, suppress_one_group)
ramp_data = create_ramp_fit_class(model, algorithm, dqflags, suppress_one_group)

if algorithm.upper() == "OLS_C":
ramp_data.run_c_code = True
Expand Down
28 changes: 27 additions & 1 deletion src/stcal/ramp_fitting/ramp_fit_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ def __init__(self):
self.pixeldq = None
self.average_dark_current = None

# Needed for CHARGELOSS recomputation
self.orig_gdq = None
self.algorithm = None

# Meta information
self.instrument_name = None

Expand All @@ -25,6 +29,7 @@ def __init__(self):
self.flags_saturated = None
self.flags_no_gain_val = None
self.flags_unreliable_slope = None
self.flags_chargeloss = None

# ZEROFRAME
self.zframe_mat = None
Expand All @@ -41,13 +46,15 @@ def __init__(self):

# C code debugging switch.
self.run_c_code = False
self.run_chargeloss = True
# self.run_chargeloss = False

self.one_groups_locs = None # One good group locations.
self.one_groups_time = None # Time to use for one good group ramps.

self.current_integ = -1

def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current):
def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current, orig_gdq=None):
"""
Set the arrays needed for ramp fitting.
Expand All @@ -72,6 +79,11 @@ def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current):
average_dark_current : ndarray (float32)
2-D array containing the average dark current. It has
dimensions (nrows, ncols)
orig_gdq : ndarray
4-D array containing a copy of the original group DQ array. Since
the group DQ array can be modified during ramp fitting, this keeps
around the original group DQ flags passed to ramp fitting.
"""
# Get arrays from the data model
self.data = data
Expand All @@ -80,6 +92,8 @@ def set_arrays(self, data, err, groupdq, pixeldq, average_dark_current):
self.pixeldq = pixeldq
self.average_dark_current = average_dark_current

self.orig_gdq = orig_gdq

def set_meta(self, name, frame_time, group_time, groupgap, nframes, drop_frames1=None):
"""
Set the metainformation needed for ramp fitting.
Expand Down Expand Up @@ -131,6 +145,8 @@ def set_dqflags(self, dqflags):
self.flags_saturated = dqflags["SATURATED"]
self.flags_no_gain_val = dqflags["NO_GAIN_VALUE"]
self.flags_unreliable_slope = dqflags["UNRELIABLE_SLOPE"]
if self.algorithm is not None and self.algorithm.upper() == "OLS_C":
self.flags_chargeloss = dqflags["CHARGELOSS"]

def dbg_print_types(self):
# Arrays from the data model
Expand Down Expand Up @@ -200,6 +216,16 @@ def dbg_print_pixel_info(self, row, col):
# print(f" err :\n{self.err[:, :, row, col]}")
# print(f" pixeldq :\n{self.pixeldq[row, col]}")

def dbg_print_info(self):
print(" ")
nints, ngroups, nrows, ncols = self.data.shape
for row in range(nrows):
for col in range(ncols):
print("=" * 80)
print(f"**** Pixel ({row}, {col}) ****")
self.dbg_print_pixel_info(row, col)
print("=" * 80)

def dbg_write_ramp_data_pix_pre(self, fname, row, col, fd):
fd.write("def create_ramp_data_pixel():\n")
indent = INDENT
Expand Down
Loading

0 comments on commit b65f560

Please sign in to comment.