From 2bfd3883cafe90552215f5e86d4eaf081d624914 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:41:56 -0400 Subject: [PATCH 1/2] Alternate organization --- src/deepdisc/data_format/file_io.py | 128 ------------- src/deepdisc/data_format/image_readers.py | 210 ++++++++++++++++++++++ src/deepdisc/inference/predictors.py | 4 +- test_eval_model.py | 18 +- test_eval_model_DC2.py | 14 +- test_eval_model_DC2_redshift.py | 16 +- test_run_transformers.py | 16 +- test_run_transformers_DC2.py | 12 +- test_run_transformers_DC2_redshift.py | 12 +- 9 files changed, 227 insertions(+), 203 deletions(-) create mode 100644 src/deepdisc/data_format/image_readers.py diff --git a/src/deepdisc/data_format/file_io.py b/src/deepdisc/data_format/file_io.py index dd60b16..0936bc7 100644 --- a/src/deepdisc/data_format/file_io.py +++ b/src/deepdisc/data_format/file_io.py @@ -1,10 +1,6 @@ import json from pathlib import Path -import numpy as np -from astropy.visualization import make_lupton_rgb - - def get_data_from_json(filename): """Open a JSON text file, and return encoded data as dictionary. @@ -28,127 +24,3 @@ def get_data_from_json(filename): with open(filename, "r", encoding="utf-8") as f: data = json.load(f) return data - - -class ImageReader: - """Class that will read images on the fly for the training/testing dataloaders""" - - def __init__(self, reader, norm="raw", **scalekwargs): - """ - Parameters - ---------- - reader : function - This function should take a single key and return a single image as a numpy array - ex) give a filename or an index in an array - norm : str - A contrast scaling to apply before data augmentation, i.e. 
luptonizing or z-score scaling - Default = raw - **scalekwargs : key word args - Key word args for the contrast scaling function - """ - self.reader = reader - self.scalekwargs = scalekwargs - self.scaling = ImageReader.norm_dict[norm] - - def __call__(self, key): - """Read the image and apply scaling. - - Parameters - ---------- - key : str or int - The key indicating the image to read. - - Returns - ------- - im : numpy array - The image. - """ - im = self.reader(key) - im_scale = self.scaling(im, **self.scalekwargs) - return im_scale - - def raw(im): - """Apply raw image scaling (no scaling done). - - Parameters - ---------- - im : numpy array - The image. - - Returns - ------- - numpy array - The image with pixels as float32. - """ - return im.astype(np.float32) - - def lupton(im, bandlist=[2, 1, 0], stretch=0.5, Q=10, m=0): - """Apply Lupton scaling to the image and return the scaled image. - - Parameters - ---------- - im : np array - The image being scaled - bandlist : list[int] - Which bands to use for lupton scaling (must be 3) - stretch : float - lupton stretch parameter - Q : float - lupton Q parameter - m: float - lupton minimum parameter - - Returns - ------- - image : numpy array - The 3-channel image after lupton scaling using astropy make_lupton_rgb - """ - assert np.array(im.shape).argmin() == 2 and len(bandlist) == 3 - b1 = im[:, :, bandlist[0]] - b2 = im[:, :, bandlist[1]] - b3 = im[:, :, bandlist[2]] - - image = make_lupton_rgb(b1, b2, b3, minimum=m, stretch=stretch, Q=Q) - return image - - def zscore(im, A=1): - """Apply z-score scaling to the image and return the scaled image. 
- - Parameters - ---------- - im : np array - The image being scaled - A : float - A multiplicative scaling factor applied to each band - - Returns - ------- - image : numpy array - The image after z-score scaling (subtract mean and divide by std deviation) - """ - I = np.mean(im, axis=-1) - Imean = np.nanmean(I) - Isigma = np.nanstd(I) - - for i in range(im.shape[-1]): - image[:, :, i] = A * (im[:, :, i] - Imean - m) / Isigma - - return image - - #This dict is created to map an input string to a scaling function - norm_dict = {"raw": raw, "lupton": lupton} - - @classmethod - def add_scaling(cls, name, func): - """Add a custom contrast scaling function - - ex) - def sqrt(image): - image[:,:,0] = np.sqrt(image[:,:,0]) - image[:,:,1] = np.sqrt(image[:,:,1]) - image[:,:,2] = np.sqrt(image[:,:,2]) - return image - - ImageReader.add_scaling('sqrt',sqrt) - """ - cls.norm_dict[name] = func diff --git a/src/deepdisc/data_format/image_readers.py b/src/deepdisc/data_format/image_readers.py new file mode 100644 index 0000000..feeca0b --- /dev/null +++ b/src/deepdisc/data_format/image_readers.py @@ -0,0 +1,210 @@ +import abc +import os +import numpy as np + +from astropy.io import fits +from astropy.visualization import make_lupton_rgb + +class ImageReader(abc.ABC): + """Base class that will read images on the fly for the training/testing dataloaders + + To implement an image reader for a new class, the derived class needs to have an + __init__() function that calls super().__init__(*args, **kwargs) + and a custom version of _read_image(). + """ + + def __init__(self, norm="raw", *args, **kwargs): + """ + Parameters + ---------- + norm : str (optional) + A contrast scaling to apply before data augmentation, i.e. 
luptonizing or z-score scaling + Default = raw + **kwargs : key word args + Key word args for the contrast scaling function + """ + self.scaling = ImageReader.norm_dict[norm] + self.scalekwargs = kwargs + + @abc.abstractmethod + def _read_image(self, key): + """Read the image. No-op implementation. + + Parameters + ---------- + key : str or int + The key indicating the image to read. + + Returns + ------- + im : numpy array + The image. + """ + pass + + def __call__(self, key): + """Read the image and apply scaling. + + Parameters + ---------- + key : str or int + The key indicating the image to read. + + Returns + ------- + im : numpy array + The image. + """ + im = self._read_image(key) + im_scale = self.scaling(im, **self.scalekwargs) + return im_scale + + def raw(im): + """Apply raw image scaling (no scaling done). + + Parameters + ---------- + im : numpy array + The image. + + Returns + ------- + numpy array + The image with pixels as float32. + """ + return im.astype(np.float32) + + def lupton(im, bandlist=[2, 1, 0], stretch=0.5, Q=10, m=0): + """Apply Lupton scaling to the image and return the scaled image. + + Parameters + ---------- + im : np array + The image being scaled + bandlist : list[int] + Which bands to use for lupton scaling (must be 3) + stretch : float + lupton stretch parameter + Q : float + lupton Q parameter + m: float + lupton minimum parameter + + Returns + ------- + image : numpy array + The 3-channel image after lupton scaling using astropy make_lupton_rgb + """ + assert np.array(im.shape).argmin() == 2 and len(bandlist) == 3 + b1 = im[:, :, bandlist[0]] + b2 = im[:, :, bandlist[1]] + b3 = im[:, :, bandlist[2]] + + image = make_lupton_rgb(b1, b2, b3, minimum=m, stretch=stretch, Q=Q) + return image + + def zscore(im, A=1): + """Apply z-score scaling to the image and return the scaled image. 
+ + Parameters + ---------- + im : np array + The image being scaled + A : float + A multiplicative scaling factor applied to each band + + Returns + ------- + image : numpy array + The image after z-score scaling (subtract mean and divide by std deviation) + """ + I = np.mean(im, axis=-1) + Imean = np.nanmean(I) + Isigma = np.nanstd(I) + + for i in range(im.shape[-1]): + image[:, :, i] = A * (im[:, :, i] - Imean - m) / Isigma + + return image + + #This dict is created to map an input string to a scaling function + norm_dict = {"raw": raw, "lupton": lupton} + + @classmethod + def add_scaling(cls, name, func): + """Add a custom contrast scaling function + + ex) + def sqrt(image): + image[:,:,0] = np.sqrt(image[:,:,0]) + image[:,:,1] = np.sqrt(image[:,:,1]) + image[:,:,2] = np.sqrt(image[:,:,2]) + return image + + ImageReader.add_scaling('sqrt',sqrt) + """ + cls.norm_dict[name] = func + + +class DC2ImageReader(ImageReader): + """An ImageReader for DC2 image files.""" + + def __init__(self, *args, **kwargs): + # Pass arguments to the parent function. + super().__init__(*args, **kwargs) + + def _read_image(self, filename): + """Read the image. + + Parameters + ---------- + filename : str + The filename indicating the image to read. + + Returns + ------- + im : numpy array + The image. + """ + file = filename.split("/")[-1].split(".")[0] + base = os.path.dirname(filename) + fn = os.path.join(base, file) + ".npy" + image = np.load(fn) + image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32) + return image + + +class HSCImageReader(ImageReader): + """An ImageReader for HSC image files.""" + + def __init__(self, *args, **kwargs): + # Pass arguments to the parent function. + super().__init__(*args, **kwargs) + + def _read_image(self, filenames): + """Read the image. + + Parameters + ---------- + filenames : list + A length 3 list of filenames for the G, R, and I images, in that order. + + Returns + ------- + im : numpy array + The image. 
+ """ + if len(filenames) != 3: + raise ValueError("Incorrect number of filenames passed.") + + g = fits.getdata(os.path.join(filenames[0]), memmap=False) + length, width = g.shape + image = np.empty([length, width, 3]) + r = fits.getdata(os.path.join(filenames[1]), memmap=False) + i = fits.getdata(os.path.join(filenames[2]), memmap=False) + + image[:, :, 0] = i + image[:, :, 1] = r + image[:, :, 2] = g + return image + \ No newline at end of file diff --git a/src/deepdisc/inference/predictors.py b/src/deepdisc/inference/predictors.py index b7221ab..e144972 100644 --- a/src/deepdisc/inference/predictors.py +++ b/src/deepdisc/inference/predictors.py @@ -30,8 +30,8 @@ def get_predictions(dataset_dict, imreader, key_mapper, predictor): ---------- dataset_dict : dictionary The dictionary metadata for a single image - imreader: ImageReader class - The ImageReader used to load in images + imreader: ImageReader object + An object derived from ImageReader base class to read the images. key_mapper: function The key_mapper should take a dataset_dict as input and return the key used by imreader predictor: AstroPredictor diff --git a/test_eval_model.py b/test_eval_model.py index ee5ebf7..1b11497 100644 --- a/test_eval_model.py +++ b/test_eval_model.py @@ -58,7 +58,8 @@ from detectron2 import structures from detectron2.structures import BoxMode -from deepdisc.data_format.file_io import ImageReader, get_data_from_json +from deepdisc.data_format.file_io import get_data_from_json +from deepdisc.data_format.image_readers import HSCImageReader from deepdisc.inference.match_objects import get_matched_object_classes from deepdisc.inference.predictors import return_predictor_transformer from deepdisc.utils.parse_arguments import dtype_from_args, make_inference_arg_parser @@ -218,19 +219,6 @@ def return_predictor( predictor, cfg = return_predictor(cfgfile, run_name, output_dir=output_dir, nc=2, roi_thresh=roi_thresh) -def hsc_image_reader(filenames): - g = 
fits.getdata(os.path.join(filenames[0]), memmap=False) - length, width = g.shape - image = np.empty([length, width, 3]) - r = fits.getdata(os.path.join(filenames[1]), memmap=False) - i = fits.getdata(os.path.join(filenames[2]), memmap=False) - - image[:, :, 0] = i - image[:, :, 1] = r - image[:, :, 2] = g - return image - - def hsc_key_mapper(dataset_dict): filenames = [ dataset_dict["filename_G"], @@ -240,7 +228,7 @@ def hsc_key_mapper(dataset_dict): return filenames -IR = ImageReader(hsc_image_reader, norm=args.norm) +IR = HSCImageReader(norm=args.norm) t0 = time.time() diff --git a/test_eval_model_DC2.py b/test_eval_model_DC2.py index f397109..a0d6e98 100644 --- a/test_eval_model_DC2.py +++ b/test_eval_model_DC2.py @@ -58,7 +58,8 @@ from detectron2 import structures from detectron2.structures import BoxMode -from deepdisc.data_format.file_io import ImageReader, get_data_from_json +from deepdisc.data_format.file_io import get_data_from_json +from deepdisc.data_format.image_readers import DC2ImageReader from deepdisc.inference.match_objects import get_matched_object_classes, get_matched_z_pdfs from deepdisc.inference.predictors import return_predictor_transformer from deepdisc.model.models import RedshiftPDFCasROIHeads @@ -229,21 +230,12 @@ def return_predictor( predictor, cfg = return_predictor(cfgfile, run_name, output_dir=output_dir, nc=2, roi_thresh=roi_thresh) -def dc2_image_reader(filename): - file = filename.split("/")[-1].split(".")[0] - base = os.path.dirname(filename) - fn = os.path.join(base, file) + ".npy" - image = np.load(fn) - image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32) - return image - - def dc2_key_mapper(dataset_dict): filename = dataset_dict["filename"] return filename -IR = ImageReader(dc2_image_reader, norm=args.norm) +IR = DC2ImageReader(norm=args.norm) t0 = time.time() diff --git a/test_eval_model_DC2_redshift.py b/test_eval_model_DC2_redshift.py index 544f2e6..eab573a 100644 --- a/test_eval_model_DC2_redshift.py +++ 
b/test_eval_model_DC2_redshift.py @@ -58,7 +58,8 @@ from detectron2 import structures from detectron2.structures import BoxMode -from deepdisc.data_format.file_io import ImageReader, get_data_from_json +from deepdisc.data_format.file_io import get_data_from_json +from deepdisc.data_format.image_readers import DC2ImageReader from deepdisc.inference.match_objects import get_matched_object_classes, get_matched_z_pdfs from deepdisc.inference.predictors import return_predictor_transformer from deepdisc.model.models import RedshiftPDFCasROIHeads @@ -228,22 +229,11 @@ def return_predictor( else: predictor, cfg = return_predictor(cfgfile, run_name, output_dir=output_dir, nc=2, roi_thresh=roi_thresh) - -def dc2_image_reader(filename): - file = filename.split("/")[-1].split(".")[0] - base = os.path.dirname(filename) - fn = os.path.join(base, file) + ".npy" - image = np.load(fn) - image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32) - return image - - def dc2_key_mapper(dataset_dict): filename = dataset_dict["filename"] return filename - -IR = ImageReader(dc2_image_reader, norm=args.norm) +IR = DC2ImageReader(norm=args.norm) t0 = time.time() diff --git a/test_run_transformers.py b/test_run_transformers.py index 138999a..88df938 100644 --- a/test_run_transformers.py +++ b/test_run_transformers.py @@ -66,7 +66,7 @@ from detectron2.structures import BoxMode from detectron2.utils.visualizer import Visualizer -from deepdisc.data_format.file_io import ImageReader +from deepdisc.data_format.image_readers import HSCImageReader from deepdisc.data_format.register_data import register_data_set from deepdisc.model.loaders import return_test_loader, return_train_loader, test_mapper_cls, train_mapper_cls from deepdisc.model.models import return_lazy_model @@ -181,18 +181,6 @@ def main(train_head, args): # optimizer = instantiate(cfg.optimizer) optimizer = return_optimizer(cfg) - #image_reader function takes a key and uses it to load a raw image - def hsc_image_reader(filenames): 
- g = fits.getdata(os.path.join(filenames[0]), memmap=False) - length, width = g.shape - image = np.empty([length, width, 3]) - r = fits.getdata(os.path.join(filenames[1]), memmap=False) - i = fits.getdata(os.path.join(filenames[2]), memmap=False) - - image[:, :, 0] = i - image[:, :, 1] = r - image[:, :, 2] = g - return image #key_mapper function should take a dataset_dict as input and output a key used by the image_reader function def hsc_key_mapper(dataset_dict): @@ -203,7 +191,7 @@ def hsc_key_mapper(dataset_dict): ] return filenames - IR = ImageReader(hsc_image_reader, norm=args.norm) + IR = HSCImageReader(norm=args.norm) mapper = train_mapper_cls(IR, hsc_key_mapper) loader = return_train_loader(cfg_loader, mapper) test_mapper = test_mapper_cls(IR, hsc_key_mapper) diff --git a/test_run_transformers_DC2.py b/test_run_transformers_DC2.py index 7ed6bf2..7fb8204 100644 --- a/test_run_transformers_DC2.py +++ b/test_run_transformers_DC2.py @@ -66,7 +66,7 @@ from detectron2.structures import BoxMode from detectron2.utils.visualizer import Visualizer -from deepdisc.data_format.file_io import ImageReader +from deepdisc.data_format.image_readers import DC2ImageReader from deepdisc.data_format.register_data import register_data_set from deepdisc.model.loaders import ( redshift_test_mapper_cls, @@ -199,19 +199,11 @@ def main(train_head, args): optimizer = return_optimizer(cfg) - def dc2_image_reader(filename): - file = filename.split("/")[-1].split(".")[0] - base = os.path.dirname(filename) - fn = os.path.join(base, file) + ".npy" - image = np.load(fn) - image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32) - return image - def dc2_key_mapper(dataset_dict): filename = dataset_dict["filename"] return filename - IR = ImageReader(dc2_image_reader, norm=args.norm) + IR = DC2ImageReader(norm=args.norm) mapper = redshift_train_mapper_cls(IR, dc2_key_mapper) loader = return_train_loader(cfg_loader, mapper) test_mapper = redshift_test_mapper_cls(IR, dc2_key_mapper) diff 
--git a/test_run_transformers_DC2_redshift.py b/test_run_transformers_DC2_redshift.py index 243c398..b6cc183 100644 --- a/test_run_transformers_DC2_redshift.py +++ b/test_run_transformers_DC2_redshift.py @@ -66,7 +66,7 @@ from detectron2.structures import BoxMode from detectron2.utils.visualizer import Visualizer -from deepdisc.data_format.file_io import ImageReader +from deepdisc.data_format.image_readers import DC2ImageReader from deepdisc.data_format.register_data import register_data_set from deepdisc.model.loaders import ( redshift_test_mapper_cls, @@ -199,19 +199,11 @@ def main(train_head, args): optimizer = return_optimizer(cfg) - def dc2_image_reader(filename): - file = filename.split("/")[-1].split(".")[0] - base = os.path.dirname(filename) - fn = os.path.join(base, file) + ".npy" - image = np.load(fn) - image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32) - return image - def dc2_key_mapper(dataset_dict): filename = dataset_dict["filename"] return filename - IR = ImageReader(dc2_image_reader, norm=args.norm) + IR = DC2ImageReader(norm=args.norm) mapper = redshift_train_mapper_cls(IR, dc2_key_mapper) loader = return_train_loader(cfg_loader, mapper) test_mapper = redshift_test_mapper_cls(IR, dc2_key_mapper) From 95a3a85997d671d2bc82421a23786ada0d5a5a7e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:53:46 -0400 Subject: [PATCH 2/2] Add tests --- tests/deepdisc/conftest.py | 16 ++++++++++++++ .../data_format/test_image_readers.py | 21 +++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 tests/deepdisc/data_format/test_image_readers.py diff --git a/tests/deepdisc/conftest.py b/tests/deepdisc/conftest.py index b1b9d20..3fb5843 100644 --- a/tests/deepdisc/conftest.py +++ b/tests/deepdisc/conftest.py @@ -10,3 +10,19 @@ def hsc_test_data_dir(): @pytest.fixture def hsc_single_test_file(hsc_test_data_dir): return path.join(hsc_test_data_dir, "single_test.json") + 
+@pytest.fixture +def hsc_triple_test_file(hsc_test_data_dir): + return [ + path.join(hsc_test_data_dir, "G-10054-0,2-c1_scarlet_img.fits"), + path.join(hsc_test_data_dir, "I-10054-0,2-c1_scarlet_img.fits"), + path.join(hsc_test_data_dir, "R-10054-0,2-c1_scarlet_img.fits"), + ] + +@pytest.fixture +def dc2_test_data_dir(): + return path.join(TEST_DIR, "test_data/dc2") + +@pytest.fixture +def dc2_single_test_file(dc2_test_data_dir): + return path.join(dc2_test_data_dir, "3828_2,2_12_images.npy") \ No newline at end of file diff --git a/tests/deepdisc/data_format/test_image_readers.py b/tests/deepdisc/data_format/test_image_readers.py new file mode 100644 index 0000000..cb055bc --- /dev/null +++ b/tests/deepdisc/data_format/test_image_readers.py @@ -0,0 +1,21 @@ +import os +import pytest + +from deepdisc.data_format.image_readers import DC2ImageReader, HSCImageReader + +def test_read_hsc_data(hsc_triple_test_file): + """Test that we can read the test HSC data.""" + ir = HSCImageReader(norm="raw") + img = ir(hsc_triple_test_file) + assert img.shape[0] == 1050 + assert img.shape[1] == 1025 + assert img.shape[2] == 3 + + +def test_read_dc2_data(dc2_single_test_file): + """Test that we can read the test DC2 data.""" + ir = DC2ImageReader(norm="raw") + img = ir(dc2_single_test_file) + assert img.shape[0] == 525 + assert img.shape[1] == 525 + assert img.shape[2] == 6