Skip to content

Commit

Permalink
Write to single zarr store (#340)
Browse files Browse the repository at this point in the history
* Write to single zarr store

* add dimension number to phase channel name

* remove unnecessary ellipses

* close dataset

* save all reconstructions to `ReconstructionSnap.zarr`
  • Loading branch information
talonchandler authored Apr 13, 2023
1 parent 8c400d6 commit 190443a
Showing 1 changed file with 31 additions and 44 deletions.
75 changes: 31 additions & 44 deletions recOrder/acq/acquisition_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def _reconstruct(self, stack):

# Perform deconvolution
if self.dim == "2D":

phase = reconstruct_phase2D(
stack[0],
recon,
Expand All @@ -308,7 +307,6 @@ def _reconstruct(self, stack):
rho=float(self.calib_window.ui.le_rho.text()),
)
else:

phase = reconstruct_phase3D(
stack[0],
recon,
Expand Down Expand Up @@ -346,7 +344,11 @@ def _save_imgs(self, phase, meta=None):
"""
prefix = self.calib_window.save_name
name = f"PhaseSnap.zarr" if not prefix else f"{prefix}_PhaseSnap.zarr"
name = (
f"ReconstructionSnap.zarr"
if not prefix
else f"{prefix}_ReconstructionSnap.zarr"
)

with open_ome_zarr(
os.path.join(self.snap_dir, name),
Expand Down Expand Up @@ -426,7 +428,6 @@ def _reconstructor_changed(self, stack_shape: tuple):
return False

def _cleanup_acq(self):

# Get display windows
disps = self.dm.getAllDataViewers()

Expand Down Expand Up @@ -725,7 +726,6 @@ def _reconstruct(self, stack):

# Initialize the heavy reconstuctor
if self.mode == "phase" or self.mode == "all":

self._check_abort()

# if no reconstructor has been initialized before, create new reconstructor
Expand Down Expand Up @@ -928,12 +928,12 @@ def _reconstruct(self, stack):

def _save_imgs(self, birefringence, phase, meta=None):
"""
function to save images. Seperates out both birefringence and phase into separate zarr stores.
function to save images.
Makes sure file names do not overlap, i.e. nothing is overwritten.
Parameters
----------
birefringence: (nd-array or None) birefringence image(s)
birefringence: (nd-array) birefringence image(s)
phase: (nd-array or None) phase image(s)
Returns
Expand All @@ -942,39 +942,32 @@ def _save_imgs(self, birefringence, phase, meta=None):
"""
prefix = self.calib_window.save_name

if birefringence is not None:
name = (
f"ReconstructionSnap.zarr"
if not prefix
else f"{prefix}_ReconstructionSnap.zarr"
)

name = (
f"BirefringenceSnap.zarr"
if not prefix
else f"{prefix}_BirefringenceSnap.zarr"
)
with open_ome_zarr(
os.path.join(self.snap_dir, name),
layout="fov",
mode="w-",
channel_names=["Retardance", "Orientation", "BF", "Pol"],
) as dataset:
if birefringence.ndim == 3:
birefringence = birefringence[:, np.newaxis, ...]
dataset["0"] = birefringence[np.newaxis, ...]
dataset.zattrs["recOrder"] = meta

if phase is not None:
name = (
f"PhaseSnap.zarr" if not prefix else f"{prefix}_PhaseSnap.zarr"
)
with open_ome_zarr(
os.path.join(self.snap_dir, name),
layout="fov",
mode="w-",
channel_names=["Retardance", "Orientation", "BF", "Pol"],
) as dataset:
if birefringence.ndim == 3:
birefringence = birefringence[:, np.newaxis] # CYX -> CZYX
birefringence = birefringence[np.newaxis] # CZYX -> TCZYX
dataset["0"] = birefringence

if phase is not None:
dataset.append_channel(
"Phase" + str(phase.ndim) + "D", resize_arrays=True
)
if phase.ndim == 2:
phase = phase[np.newaxis] # YX -> ZYX
dataset["0"][0, 4] = phase

with open_ome_zarr(
os.path.join(self.snap_dir, name),
layout="fov",
mode="w-",
channel_names=["Phase" + str(phase.ndim) + "D"],
) as dataset:
dataset["0"] = phase[
(5 - phase.ndim) * (np.newaxis,) + (Ellipsis,)
]
dataset.zattrs["recOrder"] = meta
dataset.zattrs["recOrder"] = meta

def _load_bg(self, path, height, width):
"""
Expand Down Expand Up @@ -1110,7 +1103,6 @@ def _reconstructor_changed(self):
return changed

def _cleanup_acq(self):

# Get display windows
disps = self.dm.getAllDataViewers()

Expand Down Expand Up @@ -1283,7 +1275,6 @@ def listen_for_images(
for dim2 in range(dims[1][0], dims[1][1]):
for dim1 in range(dims[2][0], dims[2][1]):
for dim0 in range(dims[3][0], dims[3][1]):

# GET OFFSET AND WAIT
if idx > 0:
try:
Expand Down Expand Up @@ -1312,7 +1303,6 @@ def listen_for_images(
* self.n_frames
* self.n_pos
):

# Assign dimensions based off acquisition order to correctly add image to array
if dim_order == 0:
t, p, c, z = dim3, dim2, dim0, dim1
Expand Down Expand Up @@ -1359,7 +1349,6 @@ def listen_for_images(

# If z-first, compute the birefringence here
if channel_dim == 1 and dim1 == self.n_channels - 1:

# Assign dimensions based off acquisition order to correctly add image to array
if dim_order == 0:
t, p, c, z = dim3, dim2, dim0, dim1
Expand Down Expand Up @@ -1388,7 +1377,6 @@ def listen_for_images(
return array, idx, dim3, dim2, dim1, dim0

def compute_and_save(self, array, p, t, z):

if self.n_slices == 1:
array = array[:, 0]

Expand Down Expand Up @@ -1505,7 +1493,6 @@ def work(self):
while total_idx < (
self.n_slices * self.n_channels * self.n_frames * self.n_pos
):

self._check_abort()

# this will loop through reading images in a single file as it's being written
Expand Down

0 comments on commit 190443a

Please sign in to comment.