Skip to content

Commit

Permalink
Merge pull request #53 from lincc-frameworks/refactor
Browse files Browse the repository at this point in the history
Create ImageReader subclasses
  • Loading branch information
jeremykubica authored Oct 30, 2023
2 parents 8a7ebe6 + 95a3a85 commit bedefa9
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 203 deletions.
128 changes: 0 additions & 128 deletions src/deepdisc/data_format/file_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import json
from pathlib import Path

import numpy as np
from astropy.visualization import make_lupton_rgb


def get_data_from_json(filename):
"""Open a JSON text file, and return encoded data as dictionary.
Expand All @@ -28,127 +24,3 @@ def get_data_from_json(filename):
with open(filename, "r", encoding="utf-8") as f:
data = json.load(f)
return data


class ImageReader:
"""Class that will read images on the fly for the training/testing dataloaders"""

def __init__(self, reader, norm="raw", **scalekwargs):
"""
Parameters
----------
reader : function
This function should take a single key and return a single image as a numpy array
ex) give a filename or an index in an array
norm : str
A contrast scaling to apply before data augmentation, i.e. luptonizing or z-score scaling
Default = raw
**scalekwargs : key word args
Key word args for the contrast scaling function
"""
self.reader = reader
self.scalekwargs = scalekwargs
self.scaling = ImageReader.norm_dict[norm]

def __call__(self, key):
"""Read the image and apply scaling.
Parameters
----------
key : str or int
The key indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
im = self.reader(key)
im_scale = self.scaling(im, **self.scalekwargs)
return im_scale

def raw(im):
"""Apply raw image scaling (no scaling done).
Parameters
----------
im : numpy array
The image.
Returns
-------
numpy array
The image with pixels as float32.
"""
return im.astype(np.float32)

def lupton(im, bandlist=[2, 1, 0], stretch=0.5, Q=10, m=0):
"""Apply Lupton scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
bandlist : list[int]
Which bands to use for lupton scaling (must be 3)
stretch : float
lupton stretch parameter
Q : float
lupton Q parameter
m: float
lupton minimum parameter
Returns
-------
image : numpy array
The 3-channel image after lupton scaling using astropy make_lupton_rgb
"""
assert np.array(im.shape).argmin() == 2 and len(bandlist) == 3
b1 = im[:, :, bandlist[0]]
b2 = im[:, :, bandlist[1]]
b3 = im[:, :, bandlist[2]]

image = make_lupton_rgb(b1, b2, b3, minimum=m, stretch=stretch, Q=Q)
return image

def zscore(im, A=1):
"""Apply z-score scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
A : float
A multiplicative scaling factor applied to each band
Returns
-------
image : numpy array
The image after z-score scaling (subtract mean and divide by std deviation)
"""
I = np.mean(im, axis=-1)
Imean = np.nanmean(I)
Isigma = np.nanstd(I)

for i in range(im.shape[-1]):
image[:, :, i] = A * (im[:, :, i] - Imean - m) / Isigma

return image

#This dict is created to map an input string to a scaling function
norm_dict = {"raw": raw, "lupton": lupton}

@classmethod
def add_scaling(cls, name, func):
"""Add a custom contrast scaling function
ex)
def sqrt(image):
image[:,:,0] = np.sqrt(image[:,:,0])
image[:,:,1] = np.sqrt(image[:,:,1])
image[:,:,2] = np.sqrt(image[:,:,2])
return image
ImageReader.add_scaling('sqrt',sqrt)
"""
cls.norm_dict[name] = func
210 changes: 210 additions & 0 deletions src/deepdisc/data_format/image_readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import abc
import os
import numpy as np

from astropy.io import fits
from astropy.visualization import make_lupton_rgb

class ImageReader(abc.ABC):
"""Base class that will read images on the fly for the training/testing dataloaders
To implement an image reader for a new class, the derived class needs to have an
__init__() function that calls super().__init__(*args, **kwargs)
and a custom version of _read_image().
"""

def __init__(self, norm="raw", *args, **kwargs):
"""
Parameters
----------
norm : str (optional)
A contrast scaling to apply before data augmentation, i.e. luptonizing or z-score scaling
Default = raw
**kwargs : key word args
Key word args for the contrast scaling function
"""
self.scaling = ImageReader.norm_dict[norm]
self.scalekwargs = kwargs

@abc.abstractmethod
def _read_image(self, key):
"""Read the image. No-op implementation.
Parameters
----------
key : str or int
The key indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
pass

def __call__(self, key):
"""Read the image and apply scaling.
Parameters
----------
key : str or int
The key indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
im = self._read_image(key)
im_scale = self.scaling(im, **self.scalekwargs)
return im_scale

def raw(im):
"""Apply raw image scaling (no scaling done).
Parameters
----------
im : numpy array
The image.
Returns
-------
numpy array
The image with pixels as float32.
"""
return im.astype(np.float32)

def lupton(im, bandlist=[2, 1, 0], stretch=0.5, Q=10, m=0):
"""Apply Lupton scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
bandlist : list[int]
Which bands to use for lupton scaling (must be 3)
stretch : float
lupton stretch parameter
Q : float
lupton Q parameter
m: float
lupton minimum parameter
Returns
-------
image : numpy array
The 3-channel image after lupton scaling using astropy make_lupton_rgb
"""
assert np.array(im.shape).argmin() == 2 and len(bandlist) == 3
b1 = im[:, :, bandlist[0]]
b2 = im[:, :, bandlist[1]]
b3 = im[:, :, bandlist[2]]

image = make_lupton_rgb(b1, b2, b3, minimum=m, stretch=stretch, Q=Q)
return image

def zscore(im, A=1):
"""Apply z-score scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
A : float
A multiplicative scaling factor applied to each band
Returns
-------
image : numpy array
The image after z-score scaling (subtract mean and divide by std deviation)
"""
I = np.mean(im, axis=-1)
Imean = np.nanmean(I)
Isigma = np.nanstd(I)

for i in range(im.shape[-1]):
image[:, :, i] = A * (im[:, :, i] - Imean - m) / Isigma

return image

#This dict is created to map an input string to a scaling function
norm_dict = {"raw": raw, "lupton": lupton}

@classmethod
def add_scaling(cls, name, func):
"""Add a custom contrast scaling function
ex)
def sqrt(image):
image[:,:,0] = np.sqrt(image[:,:,0])
image[:,:,1] = np.sqrt(image[:,:,1])
image[:,:,2] = np.sqrt(image[:,:,2])
return image
ImageReader.add_scaling('sqrt',sqrt)
"""
cls.norm_dict[name] = func


class DC2ImageReader(ImageReader):
"""An ImageReader for DC2 image files."""

def __init__(self, *args, **kwargs):
# Pass arguments to the parent function.
super().__init__(*args, **kwargs)

def _read_image(self, filename):
"""Read the image.
Parameters
----------
filename : str
The filename indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
file = filename.split("/")[-1].split(".")[0]
base = os.path.dirname(filename)
fn = os.path.join(base, file) + ".npy"
image = np.load(fn)
image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32)
return image


class HSCImageReader(ImageReader):
"""An ImageReader for HSC image files."""

def __init__(self, *args, **kwargs):
# Pass arguments to the parent function.
super().__init__(*args, **kwargs)

def _read_image(self, filenames):
"""Read the image.
Parameters
----------
filenames : list
A length 3 list of filenames for the I, R, and G images.
Returns
-------
im : numpy array
The image.
"""
if len(filenames) != 3:
raise ValueError("Incorrect number of filenames passed.")

g = fits.getdata(os.path.join(filenames[0]), memmap=False)
length, width = g.shape
image = np.empty([length, width, 3])
r = fits.getdata(os.path.join(filenames[1]), memmap=False)
i = fits.getdata(os.path.join(filenames[2]), memmap=False)

image[:, :, 0] = i
image[:, :, 1] = r
image[:, :, 2] = g
return image

4 changes: 2 additions & 2 deletions src/deepdisc/inference/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_predictions(dataset_dict, imreader, key_mapper, predictor):
----------
dataset_dict : dictionary
The dictionary metadata for a single image
imreader: ImageReader class
The ImageReader used to load in images
imreader: ImageReader object
An object derived from ImageReader base class to read the images.
key_mapper: function
The key_mapper should take a dataset_dict as input and return the key used by imreader
predictor: AstroPredictor
Expand Down
Loading

0 comments on commit bedefa9

Please sign in to comment.