Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preprocessing #8

Merged
merged 11 commits into from
May 18, 2023
14 changes: 13 additions & 1 deletion src/membrain_seg/dataloading/data_utils.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor adjustments to normalize tomograms and return pixel size from the header.

Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def read_nifti(nifti_file):
return a


def load_tomogram(filename, return_header=False, normalize_data=False):
def load_tomogram(
filename, return_pixel_size=False, return_header=False, normalize_data=False
):
"""
Loads data and transposes s.t. we have data in the form x,y,z.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Loads data and transposes s.t. we have data in the form x,y,z.
Loads data and transposes s.t. we have data in the form x,y,z.

I hadn't noticed this before - could you explain why the transpose is necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mrcfile package loads the tomograms by default in format (z, x, y) if I remember correctly. It's not really necessary to transpose the axes. I just find it more intuitive to have the axes ordered (x, y, z). Removing this transpose should not matter for any functionalities, though.


Expand All @@ -73,8 +75,11 @@ def load_tomogram(filename, return_header=False, normalize_data=False):
"pixel_spacing": pixel_spacing,
}
if normalize_data:
data = data.astype(float)
data -= np.mean(data)
data /= np.std(data)
if return_pixel_size:
return data, mrc.voxel_size
if return_header:
return data, header_dict
return data
Expand All @@ -91,3 +96,10 @@ def store_tomogram(filename, tomogram, header_dict=None):
mrc.header.cella = header_dict["cella"]
mrc.header.cellb = header_dict["cellb"]
mrc.header.origin = header_dict["origin"]


def normalize_tomogram(tomogram):
    """
    Normalize a tomogram to zero mean and unit standard deviation.

    The computation is done on a copy, so the array passed in is never
    modified. Inputs that are not of a floating/complex dtype (e.g. integer
    tomograms) are converted with ``astype(float)`` first, because in-place
    float arithmetic on an integer array raises a casting error. Floating
    inputs keep their dtype.

    Parameters
    ----------
    tomogram : array-like
        Voxel intensities to normalize.

    Returns
    -------
    np.ndarray
        New array with the same shape as the input, mean 0 and standard
        deviation 1. A constant input (standard deviation 0) produces
        non-finite values, since the data is divided by its standard
        deviation without a guard.
    """
    data = np.asarray(tomogram)
    # Keep float32/float64/... as they are; promote everything else to float.
    target_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else float
    # astype() copies by default, which decouples the result from the input.
    data = data.astype(target_dtype)
    data -= np.mean(data)
    data /= np.std(data)
    return data
2 changes: 1 addition & 1 deletion src/membrain_seg/dataloading/memseg_augmentation.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found a small error in the augmentation script: if "prob_to_one" (i.e. maximum augmentations) is specified, it should apply the flipping with probability 0.5 instead of 1.0. Otherwise the flipping always happens and is redundant.
(Probably I should have added this in another branch, but I hope it's also okay here)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all good here for sure!

Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def get_training_transforms(prob_to_one=False, return_as_list=False):
),
prob=(1.0 if prob_to_one else 0.25),
),
RandAxisFlipd(keys=("image", "label"), prob=(1.0 if prob_to_one else 0.5)),
RandAxisFlipd(keys=("image", "label"), prob=(0.5)),
BlankCuboidTransform(
keys=["image"],
prob=(1.0 if prob_to_one else 0.4),
Expand Down
3 changes: 1 addition & 2 deletions src/membrain_seg/networks/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def on_train_epoch_end(self):
outputs = self.training_step_outputs
train_loss, num_items = 0, 0
for output in outputs:
train_loss += output["train_loss"].sum().item()
train_loss += output["train_loss"].sum().item() * output["train_number"]
num_items += output["train_number"]
mean_train_loss = torch.tensor(train_loss / num_items)

Expand Down Expand Up @@ -262,7 +262,6 @@ def on_validation_epoch_end(self):
self.dice_metric.reset()
mean_val_loss = torch.tensor(val_loss / num_items)


mean_val_acc = self.running_val_acc / num_items
self.running_val_acc = 0.0
self.log("val_loss", mean_val_loss), # batch_size=num_items)
Expand Down
1 change: 1 addition & 0 deletions src/tomo_preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Empty init."""
77 changes: 77 additions & 0 deletions src/tomo_preprocessing/extract_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# --------------------------------------------------------------------------------
# Copyright (C) 2022 ZauggGroup
#
# This file is a copy (or a modified version) of the original file from the
# following GitHub repository:
#
# Repository: https://github.com/ZauggGroup/DeePiCt
# Original file: https://github.com/ZauggGroup/DeePiCt/blob/main/spectrum_filter/extract_spectrum.py
# Repository URL: https://github.com/ZauggGroup/DeePiCt
# Original author(s): de Teresa, I.*, Goetz S.K.*, Mattausch, A., Stojanovska, F.,
# Zimmerli C., Toro-Nahuelpan M., Cheng, D.W.C., Tollervey, F. , Pape, C.,
# Beck, M., Diz-Muñoz, A., Kreshuk, A., Mahamid, J. and Zaugg, J.
# License: Apache License 2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------------

import argparse

from membrain_seg.dataloading.data_utils import load_tomogram, normalize_tomogram

from tomo_preprocessing.matching_utils.spec_matching_utils import extract_spectrum


def main():
    """
    Extract the amplitude spectrum of a tomogram and write it to a TSV file.

    The tomogram is already loaded by the argument parser (the ``--input``
    option uses ``load_tomogram`` as its ``type``), so ``args.input`` holds the
    volume itself rather than a path.
    """
    cli_args = get_cli().parse_args()

    # Normalize the loaded volume before measuring its spectrum.
    normalized_tomo = normalize_tomogram(cli_args.input)
    amplitude_spectrum = extract_spectrum(normalized_tomo)

    # Tab-separated output with a "freq" index column and an "intensity" column.
    amplitude_spectrum.to_csv(
        cli_args.output,
        sep="\t",
        header=["intensity"],
        index_label="freq",
    )


def get_cli():
    """
    Build the argument parser for the spectrum extraction script.

    Returns
    -------
    argparse.ArgumentParser
        Parser with two required options: ``-i/--input`` (converted by
        ``load_tomogram`` while parsing) and ``-o/--output``.
    """
    cli = argparse.ArgumentParser(
        description="Extract radially averaged amplitude spectrum from cryo-ET data."
    )
    add_option = cli.add_argument

    # load_tomogram acts as the argparse type, so the parsed value is the
    # loaded tomogram and not the path string given on the command line.
    add_option(
        "-i",
        "--input",
        required=True,
        type=load_tomogram,
        help="Tomogram to extract spectrum from (.mrc/.rec format)",
    )
    add_option(
        "-o",
        "--output",
        required=True,
        help="Output destination for extracted spectrum (.tsv format)",
    )
    return cli


# Entry point: run main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
88 changes: 88 additions & 0 deletions src/tomo_preprocessing/match_pixel_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import argparse
import os

from membrain_seg.dataloading.data_utils import (
load_tomogram,
normalize_tomogram,
store_tomogram,
)

from tomo_preprocessing.matching_utils.px_matching_utils import (
determine_output_shape,
fourier_cropping,
fourier_extend,
)


def main():
    """
    Match the pixel size of the input tomogram to the target pixel size.

    Loads the (normalized) tomogram, rescales it in Fourier space to the
    shape implied by the ratio of input to output pixel size, normalizes the
    result again and stores it at the requested output path.
    """
    cli_args = get_cli().parse_args()
    tomo_path = cli_args.input_tomogram
    pixel_size_out = cli_args.pixel_size_out

    # With return_pixel_size=True the loader also hands back the voxel size
    # read from the file.
    data, header_voxel_size = load_tomogram(
        tomo_path, return_pixel_size=True, normalize_data=True
    )
    # An explicitly given --pixel_size_in takes precedence over the header value.
    pixel_size_in = cli_args.pixel_size_in or header_voxel_size.x
    smoothing = not cli_args.disable_smooth

    message_parts = (
        "Matching input tomogram",
        os.path.basename(tomo_path),
        "from pixel size",
        pixel_size_in,
        "to pixel size",
        pixel_size_out,
        ".",
    )
    print(*message_parts)

    # Shape of the volume after pixel size matching.
    output_shape = determine_output_shape(pixel_size_in, pixel_size_out, data.shape)

    # A size ratio below 1 selects Fourier cropping, anything else Fourier extension.
    if pixel_size_in / pixel_size_out < 1.0:
        resize = fourier_cropping
    else:
        resize = fourier_extend
    resized_data = normalize_tomogram(resize(data, output_shape, smoothing))

    store_tomogram(cli_args.output_path, resized_data)


def get_cli():
    """
    Build the command line interface parser for pixel size matching.

    Returns
    -------
    argparse.ArgumentParser
        Parser with the positional arguments ``input_tomogram`` and
        ``output_path`` and the options ``--pixel_size_out`` (default 10.0),
        ``--pixel_size_in`` (default None) and ``--disable_smooth``
        (default False, expects an explicit True/False value).
    """
    parser = argparse.ArgumentParser(description="Match tomogram pixel size")
    parser.add_argument("input_tomogram", help="Path to the input tomogram")
    parser.add_argument("output_path", help="Path to store the output files")
    parser.add_argument(
        "--pixel_size_out",
        type=float,
        default=10.0,
        help="Target pixel size (default: 10.0)",
    )
    parser.add_argument(
        "--pixel_size_in",
        type=float,
        default=None,
        help="Input pixel size (optional). If not specified, it will be read "
        "from the tomogram's header. ATTENTION: This can lead to severe "
        "errors if the header pixel size is not correct.",
    )
    parser.add_argument(
        "--disable_smooth",
        # bool("False") is True, so type=bool cannot be used to parse the value.
        type=_str2bool,
        default=False,
        help="Disable smoothing (ellipsoid mask + cosine decay). Disable if "
        "causing problems or for speed up",
    )
    # Callers invoke .parse_args() on the result, so the parser must be returned.
    return parser


def _str2bool(value):
    """Convert a command-line string such as "True" or "false" into a bool."""
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (True/False), got {value!r}."
    )


# Entry point: run main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
131 changes: 131 additions & 0 deletions src/tomo_preprocessing/match_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# --------------------------------------------------------------------------------
# Copyright (C) 2022 ZauggGroup
#
# This file is a copy (or a modified version) of the original file from the
# following GitHub repository:
#
# Repository: https://github.com/ZauggGroup/DeePiCt
# Original file: https://github.com/ZauggGroup/DeePiCt/blob/main/spectrum_filter/match_spectrum.py
# Repository URL: https://github.com/ZauggGroup/DeePiCt
# Original author(s): de Teresa, I.*, Goetz S.K.*, Mattausch, A., Stojanovska, F.,
# Zimmerli C., Toro-Nahuelpan M., Cheng, D.W.C., Tollervey, F. , Pape, C.,
# Beck, M., Diz-Muñoz, A., Kreshuk, A., Mahamid, J. and Zaugg, J.
# License: Apache License 2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------------

import argparse

import pandas as pd
from membrain_seg.dataloading.data_utils import (
load_tomogram,
normalize_tomogram,
store_tomogram,
)
from membrain_seg.parse_utils import str2bool

from tomo_preprocessing.matching_utils.spec_matching_utils import match_spectrum


def main():
    """
    Match the amplitude spectrum of the input tomogram to a target spectrum.

    Loads the (normalized) tomogram and the target spectrum, filters the
    tomogram with ``match_spectrum``, normalizes the result again and stores
    it at the requested output location.
    """
    cli_args = get_cli().parse_args()

    tomogram = load_tomogram(cli_args.input, normalize_data=True)

    # The target spectrum is a tab-separated table; only the "intensity"
    # column is used.
    spectrum_table = pd.read_csv(cli_args.target, sep="\t")
    target_spectrum = spectrum_table["intensity"].values

    matched_tomo = match_spectrum(
        tomogram,
        target_spectrum,
        cli_args.cutoff,
        cli_args.smoothen,
        cli_args.almost_zero_cutoff,
        cli_args.shrink_excessive_value,
    )

    store_tomogram(cli_args.output, normalize_tomogram(matched_tomo))


def get_cli():
    """
    Set up the command line interface for spectrum matching.

    Returns
    -------
    argparse.ArgumentParser
        Parser with the required options ``-i/--input``, ``-t/--target`` and
        ``-o/--output`` and the optional tuning options ``-c/--cutoff``,
        ``--shrink_excessive_value``, ``--almost_zero_cutoff`` and
        ``-s/--smoothen``.
    """
    parser = argparse.ArgumentParser(
        description="Match tomogram to another tomogram's amplitude spectrum."
    )

    # main() passes these three values straight to the loader, the TSV reader
    # and the writer, so a missing one is reported by argparse up front
    # instead of failing later on a None path.
    parser.add_argument(
        "-i", "--input", required=True, help="Tomogram to match (.mrc/.rec)"
    )

    parser.add_argument(
        "-t",
        "--target",
        required=True,
        help="Target spectrum to match the input tomogram to (.tsv)",
    )

    parser.add_argument(
        "-o", "--output", required=True, help="Output location for matched tomogram"
    )

    parser.add_argument(
        "-c",
        "--cutoff",
        required=False,
        # NOTE(review): the default is False rather than an int - presumably
        # treated downstream as "no low-pass cutoff"; confirm in match_spectrum.
        default=False,
        type=int,
        help="Lowpass cutoff to apply. All frequencies above this value will "
        "be set to zero.",
    )

    parser.add_argument(
        "--shrink_excessive_value",
        required=False,
        default=5e1,
        # The default is a float, so parse user input as float too; type=int
        # rejected fractional thresholds such as "50.5".
        type=float,
        help="Regularization for excessive values. All Fourier coefficients "
        "above this values will be set to the value.",
    )

    parser.add_argument(
        "--almost_zero_cutoff",
        type=str2bool,
        default=True,
        help='Pass "True" or "False". Should Fourier coefficients close to '
        "zero be ignored? "
        "Recommended particularly in combination with pixel size matching. "
        "Defaults to True.",
    )

    parser.add_argument(
        "-s",
        "--smoothen",
        required=False,
        default=10,
        type=float,
        help="Smoothening to apply to lowpass filter. Value roughly resembles sigmoid"
        " width in pixels",
    )

    return parser


# Entry point: run main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions src/tomo_preprocessing/matching_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Empty init."""
Loading