Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subsampling nwb #78

Draft
wants to merge 10 commits into
base: variable-rois
Choose a base branch
from
1 change: 1 addition & 0 deletions src/silverlabnwb/nwb_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ def _write_roi_data(self, all_rois, num_trials, cycles_per_trial,
# Reshape the TDMS data into an nd array
# TODO: Consider precision: the round() here is to match the exported data...
ch_data = np.round(tdms_file['Functional Imaging Data'][f'Channel {ch} Data'].data)
assert ch_data.size == total_pixels * cycles_per_trial
# Copy each ROI's data into the NWB
for roi_num, data_paths in all_rois.items():
roi_shape = all_roi_dimensions_pixels[roi_num - 1, :]
Expand Down
62 changes: 33 additions & 29 deletions src/silverlabnwb/subsample_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ def subsample_nwb(nwb, input_path, output_path, ntrials=2, nrois=10):
input_path, output_path,
ntrials, nrois, orig_nrois)
# Figure out time duration for given ntrials
last_trial = nwb['/epochs/trial_{:04d}'.format(ntrials)]
end_time = last_trial['stop_time'].value
end_time = nwb['/intervals/epochs/']['stop_time'][ntrials-1]
print('Trial {} ends at {}'.format(ntrials, end_time))
# Copy truncated speed data
copy_speed_data(input_path, output_path, last_trial)
copy_speed_data(input_path, output_path, nwb['/intervals/epochs/']['timeseries'][ntrials-1])
# Figure out which Zstack planes have ROIs
zstack_planes = find_used_planes(nwb, nrois)
# Copy pockels file
Expand All @@ -57,7 +56,7 @@ def subsample_nwb(nwb, input_path, output_path, ntrials=2, nrois=10):
tdms_out = os.path.join(output_path, tdms_path, tdms_name)
copy_tdms(nwb, tdms_in, tdms_out, nrois)
# Find videos defined
video_names = [name for name in nwb['/acquisition/timeseries'].keys()
video_names = [name for name in nwb['/acquisition'].keys()
if name.endswith('Cam')]
print('Videos:', video_names)
# Compress videos
Expand All @@ -77,7 +76,7 @@ def find_used_planes(nwb, nrois):
used_planes = set()
seg_iface = nwb['/processing/Acquired_ROIs/ImageSegmentation']
for plane_name in seg_iface.keys():
plane_num = int(plane_name[-4:])
plane_num = int(plane_name[6:10])
for roi_name in seg_iface[plane_name].keys():
if roi_name.startswith('ROI_'):
roi_num = int(roi_name[4:])
Expand All @@ -103,13 +102,13 @@ def copy_zstack(input_path, output_path, zstack_planes):
im.save(dest, format='TIFF', compression='tiff_lzw')


def copy_speed_data(input_path, output_path, last_trial):
def copy_speed_data(input_path, output_path, last_trial_speed_data):
fname = 'Speed_Data/Speed data 001.txt'
os.makedirs(os.path.join(output_path, 'Speed_Data'), exist_ok=True)
src = os.path.join(input_path, fname)
dest = os.path.join(output_path, fname)
speed_data_ts = last_trial['speed_data']
end_index = speed_data_ts['idx_start'].value + speed_data_ts['count'].value
speed_data_ts = last_trial_speed_data
end_index = speed_data_ts['idx_start'] + speed_data_ts['count']
copy_and_truncate(src, dest, end_index + 1)


Expand Down Expand Up @@ -189,32 +188,37 @@ def cycles_per_trial(nwb):
trial. Currently looks at the first imaging timeseries in
the first trial, and assumes they're all the same.
"""
trial1 = nwb['/epochs/trial_0001']
for ts_name in trial1:
ts = trial1[ts_name]
is_image_series = ts['timeseries/pixel_time_offsets'] is not None
if is_image_series:
return ts['count'].value
else:
raise ValueError('No imaging timeseries found')


def copy_tdms(nwb, in_path, out_path, nrois):
num_all_rois = nwb['/processing/Acquired_ROIs/roi_spec'].shape[0]
print('Copying {} of {} ROIs from {} to {}'.format(
nrois, num_all_rois, in_path, out_path))
n_all_trials = len(nwb['/intervals/trials/id'])
return np.int(nwb['acquisition/ROI_001_Red/timestamps'].shape[0] / n_all_trials)
Comment on lines +191 to +192
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get this directly from nwb['general/silverlab_optophysiology/cycles_per_trial']?



def copy_tdms(nwb, in_path, out_path, nrois, start_roi=0):
all_rois = [nwb['/acquisition/'+key] for key in nwb['/acquisition/'].keys() if key.startswith('ROI') and key.endswith('Red')]
num_all_rois = len(all_rois)
all_roi_dimensions_pixels = np.array([roi['dimension'] for roi in all_rois])
print('Copying ROIs {} to {} out of a total of {} from {} to {}'.format(
start_roi+1, nrois, num_all_rois, in_path, out_path))
in_tdms = nptdms.TdmsFile(in_path)
group_name = 'Functional Imaging Data'
with nptdms.TdmsWriter(out_path) as out_tdms:
root, group = in_tdms.object(), in_tdms.object(group_name)
out_tdms.write_segment([root, group])
group = in_tdms[group_name]
out_tdms.write_segment([group])
for ch, channel in {'0': 'Red', '1': 'Green'}.items():
ch_name = 'Channel {} Data'.format(ch)
ch_obj = in_tdms.object(group_name, ch_name)
shape = (cycles_per_trial(nwb), num_all_rois, -1)
ch_data = ch_obj.data.reshape(shape)
subset = ch_data[:, :nrois, :].reshape(-1)
new_obj = nptdms.ChannelObject(group_name, ch_name, subset, properties={})
ch_obj = in_tdms[group_name][ch_name]
# The number of pixels for each ROI for one cycle, and for all ROIs
all_roi_pixels = all_roi_dimensions_pixels.prod(axis=1)
total_pixels = all_roi_pixels.sum()
# How many pixels to keep in each cycle from the chosen ROIs
pixels_kept_per_cycle = all_roi_pixels[start_roi:nrois].sum()
preceding_pixels_per_cycle = all_roi_pixels[:start_roi].sum()
# Build the indices of the pixels we're keeping from each cycle
first_pixels = np.arange(cycles_per_trial(nwb)) * total_pixels
remaining_pixels_per_cycle = preceding_pixels_per_cycle + np.arange(pixels_kept_per_cycle)
inds = first_pixels[:, np.newaxis] + remaining_pixels_per_cycle
# Get all pixels in these ranges and store them in a 1d array
ch_data = ch_obj.data[inds].flatten()
new_obj = nptdms.ChannelObject(group_name, ch_name, ch_data, properties={})
out_tdms.write_segment([new_obj])


Expand Down