Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature:funciton branch #66

Merged
merged 6 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dependencies = [
"fastapi_pagination==0.12.5",
"plotly==5.15.0",
"kaleido==0.2.1",
"filelock",
]

[project.optional-dependencies]
Expand Down
27 changes: 22 additions & 5 deletions studio/app/common/core/rules/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
from dataclasses import asdict
from datetime import datetime

from filelock import FileLock

from studio.app.common.core.experiment.experiment_reader import ExptConfigReader
from studio.app.common.core.snakemake.smk import Rule
from studio.app.common.core.utils.config_handler import ConfigWriter
from studio.app.common.core.utils.filepath_creater import join_filepath
from studio.app.common.core.utils.pickle_handler import PickleReader, PickleWriter
from studio.app.const import DATE_FORMAT
from studio.app.dir_path import DIRPATH
from studio.app.optinist.core.nwb.nwb_creater import merge_nwbfile, save_nwb
from studio.app.optinist.core.nwb.nwb_creater import (
merge_nwbfile,
overwrite_nwbfile,
save_nwb,
)
from studio.app.wrappers import wrapper_dict


Expand Down Expand Up @@ -57,7 +63,7 @@ def run(cls, __rule: Rule, last_output):
if __rule.output in last_output:
# 全体の結果を保存する
path = join_filepath(os.path.dirname(os.path.dirname(__rule.output)))
path = join_filepath([path, f"whole_{__rule.type}.nwb"])
path = join_filepath([path, "whole.nwb"])
cls.save_all_nwb(path, output_info["nwbfile"])

print("output: ", __rule.output)
Expand Down Expand Up @@ -100,10 +106,21 @@ def save_func_nwb(cls, save_path, name, nwbfile, output_info):
def save_all_nwb(cls, save_path, all_nwbfile):
input_nwbfile = all_nwbfile["input"]
all_nwbfile.pop("input")
nwbfile = {}
nwbconfig = {}
for x in all_nwbfile.values():
nwbfile = merge_nwbfile(nwbfile, x)
save_nwb(save_path, input_nwbfile, nwbfile)
nwbconfig = merge_nwbfile(nwbconfig, x)
# 同一のnwbfileに対して、複数の関数を実行した場合、h5pyエラーが発生する
lock_path = save_path + ".lock"
timeout = 30 # ロック取得のタイムアウト時間(秒)
with FileLock(lock_path, timeout=timeout):
# ロックが取得できたら、ファイルに書き込みを行う
try:
if os.path.exists(save_path):
overwrite_nwbfile(save_path, nwbconfig)
else:
save_nwb(save_path, input_nwbfile, nwbconfig)
except Exception as e:
print(e)

@classmethod
def execute_function(cls, path, params, nwb_params, output_dir, input_info):
Expand Down
204 changes: 119 additions & 85 deletions studio/app/optinist/core/nwb/nwb_creater.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,22 @@ def ophys(cls, nwbfile):
return nwbfile

@classmethod
def motion_correction(cls, nwbfile, mc_data, xy_trans_data):
def add_plane_segmentation(cls, nwbfile, function_id):
image_seg = nwbfile.processing["ophys"].data_interfaces["ImageSegmentation"]
if "TwoPhotonSeries" in nwbfile.acquisition:
reference_images = nwbfile.acquisition["TwoPhotonSeries"]

image_seg.create_plane_segmentation(
name=function_id,
description="output",
imaging_plane=nwbfile.imaging_planes["ImagingPlane"],
reference_images=reference_images,
)

return nwbfile

@classmethod
def motion_correction(cls, nwbfile, function_id, mc_data, xy_trans_data):
# image_data = mc_data.data
image_path = mc_data.path
corrected = ImageSeries(
Expand Down Expand Up @@ -146,26 +161,20 @@ def motion_correction(cls, nwbfile, mc_data, xy_trans_data):
motion_correction = MotionCorrection(
corrected_image_stacks=corrected_image_stack
)
nwbfile.processing["ophys"].add(motion_correction)

return nwbfile

@classmethod
def column(cls, nwbfile, name, discription, data):
data_interfaces = nwbfile.processing["ophys"].data_interfaces
plane_seg = data_interfaces["ImageSegmentation"].plane_segmentations[
"PlaneSegmentation"
]
plane_seg.add_column(name, discription, data)
function_process = nwbfile.create_processing_module(
name=function_id, description="prcesssed by " + function_id
)
function_process.add(motion_correction)

return nwbfile

@classmethod
def roi(cls, nwbfile, roi_list):
data_interfaces = nwbfile.processing["ophys"].data_interfaces
plane_seg = data_interfaces["ImageSegmentation"].plane_segmentations[
"PlaneSegmentation"
]
def roi(cls, nwbfile, function_id, roi_list):
image_seg = nwbfile.processing["ophys"].data_interfaces["ImageSegmentation"]
nwbfile = cls.add_plane_segmentation(nwbfile, function_id)
plane_seg = image_seg.plane_segmentations[function_id]

if roi_list:
for col in roi_list[0]:
if col != "pixel_mask" and col not in plane_seg.colnames:
Expand All @@ -177,26 +186,33 @@ def roi(cls, nwbfile, roi_list):
return nwbfile

@classmethod
def fluorescence(
cls, nwbfile, table_name, region, name, data, unit, timestamps=None, rate=0.0
):
data_interfaces = nwbfile.processing["ophys"].data_interfaces
plane_seg = data_interfaces["ImageSegmentation"].plane_segmentations[
"PlaneSegmentation"
]

region_roi = plane_seg.create_roi_table_region(table_name, region=region)

roi_resp_series = RoiResponseSeries(
name=name,
data=data,
rois=region_roi,
unit=unit,
timestamps=timestamps,
rate=float(rate),
)
def column(cls, nwbfile, function_id, name, discription, data):
image_seg = nwbfile.processing["ophys"].data_interfaces["ImageSegmentation"]
plane_seg = image_seg.plane_segmentations[function_id]
plane_seg.add_column(name, discription, data)

fluo = Fluorescence(name=name, roi_response_series=roi_resp_series)
return nwbfile

@classmethod
def fluorescence(cls, nwbfile, function_id, roi_list):
image_seg = nwbfile.processing["ophys"].data_interfaces["ImageSegmentation"]
plane_seg = image_seg.plane_segmentations[function_id]
fluo = Fluorescence(name=function_id)
for key in roi_list.keys():
roi = roi_list[key]
region_roi = plane_seg.create_roi_table_region(
roi["table_name"], region=roi["region"]
)

roi_resp_series = RoiResponseSeries(
name=roi["name"],
data=roi["data"],
rois=region_roi,
unit=roi["unit"],
timestamps=roi.get("timestamps"),
rate=float(roi.get("rate", 0.0)),
)
fluo.add_roi_response_series(roi_resp_series)

nwbfile.processing["ophys"].add(fluo)

Expand Down Expand Up @@ -231,13 +247,15 @@ def behavior(cls, nwbfile, key, value):
return nwbfile

@classmethod
def postprocess(cls, nwbfile, key, value):
postprocess = PostProcess(
name=key,
data=value,
)
def postprocess(cls, nwbfile, function_id, data):
nwbfile.create_processing_module(name=function_id, description="description")

nwbfile.processing["optinist"].add_container(postprocess)
for key, value in data.items():
postprocess = PostProcess(
name=key,
data=value,
)
nwbfile.processing[function_id].add_container(postprocess)

return nwbfile

Expand Down Expand Up @@ -311,12 +329,12 @@ def reaqcuisition(cls, nwbfile):
return new_nwbfile


def save_nwb(save_path, input_config, config):
nwbfile = NWBCreater.acquisition(input_config)

def set_nwbconfig(nwbfile, config):
if NWBDATASET.POSTPROCESS in config:
for key, value in config[NWBDATASET.POSTPROCESS].items():
NWBCreater.postprocess(nwbfile, key, value)
for function_key in config[NWBDATASET.POSTPROCESS]:
NWBCreater.postprocess(
nwbfile, function_key, config[NWBDATASET.POSTPROCESS][function_key]
)

if NWBDATASET.TIMESERIES in config:
for key, value in config[NWBDATASET.TIMESERIES].items():
Expand All @@ -327,25 +345,60 @@ def save_nwb(save_path, input_config, config):
NWBCreater.behavior(nwbfile, key, value)

if NWBDATASET.MOTION_CORRECTION in config:
for mc in config[NWBDATASET.MOTION_CORRECTION].values():
nwbfile = NWBCreater.motion_correction(nwbfile, **mc)
for function_key in config[NWBDATASET.MOTION_CORRECTION]:
nwbfile = NWBCreater.motion_correction(
nwbfile,
function_key,
**config[NWBDATASET.MOTION_CORRECTION][function_key],
)

if NWBDATASET.ROI in config:
for roi_list in config[NWBDATASET.ROI].values():
NWBCreater.roi(nwbfile, roi_list)
for function_key in config[NWBDATASET.ROI]:
nwbfile = NWBCreater.roi(
nwbfile, function_key, config[NWBDATASET.ROI][function_key]
)

if NWBDATASET.COLUMN in config:
for value in config[NWBDATASET.COLUMN].values():
nwbfile = NWBCreater.column(nwbfile, **value)
for function_key in config[NWBDATASET.COLUMN]:
nwbfile = NWBCreater.column(
nwbfile, function_key, **config[NWBDATASET.COLUMN][function_key]
)

if NWBDATASET.FLUORESCENCE in config:
for value in config[NWBDATASET.FLUORESCENCE].values():
nwbfile = NWBCreater.fluorescence(nwbfile, **value)
for function_key in config[NWBDATASET.FLUORESCENCE]:
nwbfile = NWBCreater.fluorescence(
nwbfile,
function_key,
config[NWBDATASET.FLUORESCENCE][function_key],
)

return nwbfile


def save_nwb(save_path, input_config, config):
nwbfile = NWBCreater.acquisition(input_config)

nwbfile = set_nwbconfig(nwbfile, config)

with NWBHDF5IO(save_path, "w") as f:
f.write(nwbfile)


def overwrite_nwbfile(save_path, config):
tmp_save_path = os.path.join(
os.path.dirname(save_path),
"tmp_" + os.path.basename(save_path),
)
with NWBHDF5IO(save_path, "r") as src_io:
old_nwbfile = src_io.read()
nwbfile = set_nwbconfig(old_nwbfile, config)
nwbfile.set_modified()
with NWBHDF5IO(tmp_save_path, mode="w") as io:
io.export(src_io=src_io, nwbfile=nwbfile)
shutil.copyfile(tmp_save_path, save_path)
os.remove(tmp_save_path)


def overwrite_nwb(config, save_path, nwb_file_name):
# バックアップファイルを作成
nwb_path = os.path.join(save_path, nwb_file_name)
Expand All @@ -356,37 +409,10 @@ def overwrite_nwb(config, save_path, nwb_file_name):
nwbfile = io.read()
# acquisition を元ファイルから作成する
new_nwbfile = NWBCreater.reaqcuisition(nwbfile)
new_nwbfile = set_nwbconfig(new_nwbfile, config)

if NWBDATASET.POSTPROCESS in config:
for key, value in config[NWBDATASET.POSTPROCESS].items():
new_nwbfile = NWBCreater.postprocess(new_nwbfile, key, value)

if NWBDATASET.TIMESERIES in config:
for key, value in config[NWBDATASET.TIMESERIES].items():
new_nwbfile = NWBCreater.timeseries(new_nwbfile, key, value)

if NWBDATASET.BEHAVIOR in config:
for key, value in config[NWBDATASET.BEHAVIOR].items():
new_nwbfile = NWBCreater.behavior(new_nwbfile, key, value)

if NWBDATASET.MOTION_CORRECTION in config:
for mc in config[NWBDATASET.MOTION_CORRECTION].values():
new_nwbfile = NWBCreater.motion_correction(new_nwbfile, **mc)

if NWBDATASET.ROI in config:
for roi_list in config[NWBDATASET.ROI].values():
new_nwbfile = NWBCreater.roi(new_nwbfile, roi_list)

if NWBDATASET.COLUMN in config:
for value in config[NWBDATASET.COLUMN].values():
new_nwbfile = nwbfile = NWBCreater.column(new_nwbfile, **value)

if NWBDATASET.FLUORESCENCE in config:
for value in config[NWBDATASET.FLUORESCENCE].values():
new_nwbfile = NWBCreater.fluorescence(new_nwbfile, **value)

with NWBHDF5IO(tmp_nwb_path, "w") as io:
io.write(new_nwbfile)
with NWBHDF5IO(tmp_nwb_path, "w") as io:
io.write(new_nwbfile)
shutil.copyfile(tmp_nwb_path, nwb_path)
os.remove(tmp_nwb_path)

Expand All @@ -403,7 +429,15 @@ def merge_nwbfile(old_nwbfile, new_nwbfile):
NWBDATASET.IMAGE_SERIES,
]:
if pattern in old_nwbfile and pattern in new_nwbfile:
old_nwbfile[pattern].update(new_nwbfile[pattern])
for function_id in new_nwbfile[pattern]:
if function_id in old_nwbfile[pattern]:
old_nwbfile[pattern][function_id].update(
new_nwbfile[pattern][function_id]
)
else:
old_nwbfile[pattern][function_id] = new_nwbfile[pattern][
function_id
]
elif pattern in new_nwbfile:
old_nwbfile[pattern] = new_nwbfile[pattern]

Expand Down
Loading