Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create ImageReader subclasses #53

Merged
merged 2 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 0 additions & 128 deletions src/deepdisc/data_format/file_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import json
from pathlib import Path

import numpy as np
from astropy.visualization import make_lupton_rgb


def get_data_from_json(filename):
"""Open a JSON text file, and return encoded data as dictionary.
Expand All @@ -28,127 +24,3 @@ def get_data_from_json(filename):
with open(filename, "r", encoding="utf-8") as f:
data = json.load(f)
return data


class ImageReader:
"""Class that will read images on the fly for the training/testing dataloaders"""

def __init__(self, reader, norm="raw", **scalekwargs):
"""
Parameters
----------
reader : function
This function should take a single key and return a single image as a numpy array
ex) give a filename or an index in an array
norm : str
A contrast scaling to apply before data augmentation, i.e. luptonizing or z-score scaling
Default = raw
**scalekwargs : key word args
Key word args for the contrast scaling function
"""
self.reader = reader
self.scalekwargs = scalekwargs
self.scaling = ImageReader.norm_dict[norm]

def __call__(self, key):
"""Read the image and apply scaling.
Parameters
----------
key : str or int
The key indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
im = self.reader(key)
im_scale = self.scaling(im, **self.scalekwargs)
return im_scale

def raw(im):
"""Apply raw image scaling (no scaling done).
Parameters
----------
im : numpy array
The image.
Returns
-------
numpy array
The image with pixels as float32.
"""
return im.astype(np.float32)

def lupton(im, bandlist=[2, 1, 0], stretch=0.5, Q=10, m=0):
"""Apply Lupton scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
bandlist : list[int]
Which bands to use for lupton scaling (must be 3)
stretch : float
lupton stretch parameter
Q : float
lupton Q parameter
m: float
lupton minimum parameter
Returns
-------
image : numpy array
The 3-channel image after lupton scaling using astropy make_lupton_rgb
"""
assert np.array(im.shape).argmin() == 2 and len(bandlist) == 3
b1 = im[:, :, bandlist[0]]
b2 = im[:, :, bandlist[1]]
b3 = im[:, :, bandlist[2]]

image = make_lupton_rgb(b1, b2, b3, minimum=m, stretch=stretch, Q=Q)
return image

def zscore(im, A=1):
"""Apply z-score scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
A : float
A multiplicative scaling factor applied to each band
Returns
-------
image : numpy array
The image after z-score scaling (subtract mean and divide by std deviation)
"""
I = np.mean(im, axis=-1)
Imean = np.nanmean(I)
Isigma = np.nanstd(I)

for i in range(im.shape[-1]):
image[:, :, i] = A * (im[:, :, i] - Imean - m) / Isigma

return image

#This dict is created to map an input string to a scaling function
norm_dict = {"raw": raw, "lupton": lupton}

@classmethod
def add_scaling(cls, name, func):
"""Add a custom contrast scaling function
ex)
def sqrt(image):
image[:,:,0] = np.sqrt(image[:,:,0])
image[:,:,1] = np.sqrt(image[:,:,1])
image[:,:,2] = np.sqrt(image[:,:,2])
return image
ImageReader.add_scaling('sqrt',sqrt)
"""
cls.norm_dict[name] = func
210 changes: 210 additions & 0 deletions src/deepdisc/data_format/image_readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import abc
import os
import numpy as np

from astropy.io import fits
from astropy.visualization import make_lupton_rgb

class ImageReader(abc.ABC):
"""Base class that will read images on the fly for the training/testing dataloaders
To implement an image reader for a new class, the derived class needs to have an
__init__() function that calls super().__init__(*args, **kwargs)
and a custom version of _read_image().
"""

def __init__(self, norm="raw", *args, **kwargs):
"""
Parameters
----------
norm : str (optional)
A contrast scaling to apply before data augmentation, i.e. luptonizing or z-score scaling
Default = raw
**kwargs : key word args
Key word args for the contrast scaling function
"""
self.scaling = ImageReader.norm_dict[norm]
self.scalekwargs = kwargs

@abc.abstractmethod
def _read_image(self, key):
"""Read the image. No-op implementation.
Parameters
----------
key : str or int
The key indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
pass

def __call__(self, key):
"""Read the image and apply scaling.
Parameters
----------
key : str or int
The key indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
im = self._read_image(key)
im_scale = self.scaling(im, **self.scalekwargs)
return im_scale

def raw(im):
"""Apply raw image scaling (no scaling done).
Parameters
----------
im : numpy array
The image.
Returns
-------
numpy array
The image with pixels as float32.
"""
return im.astype(np.float32)

def lupton(im, bandlist=[2, 1, 0], stretch=0.5, Q=10, m=0):
"""Apply Lupton scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
bandlist : list[int]
Which bands to use for lupton scaling (must be 3)
stretch : float
lupton stretch parameter
Q : float
lupton Q parameter
m: float
lupton minimum parameter
Returns
-------
image : numpy array
The 3-channel image after lupton scaling using astropy make_lupton_rgb
"""
assert np.array(im.shape).argmin() == 2 and len(bandlist) == 3
b1 = im[:, :, bandlist[0]]
b2 = im[:, :, bandlist[1]]
b3 = im[:, :, bandlist[2]]

image = make_lupton_rgb(b1, b2, b3, minimum=m, stretch=stretch, Q=Q)
return image

def zscore(im, A=1):
"""Apply z-score scaling to the image and return the scaled image.
Parameters
----------
im : np array
The image being scaled
A : float
A multiplicative scaling factor applied to each band
Returns
-------
image : numpy array
The image after z-score scaling (subtract mean and divide by std deviation)
"""
I = np.mean(im, axis=-1)
Imean = np.nanmean(I)
Isigma = np.nanstd(I)

for i in range(im.shape[-1]):
image[:, :, i] = A * (im[:, :, i] - Imean - m) / Isigma

return image

#This dict is created to map an input string to a scaling function
norm_dict = {"raw": raw, "lupton": lupton}

@classmethod
def add_scaling(cls, name, func):
"""Add a custom contrast scaling function
ex)
def sqrt(image):
image[:,:,0] = np.sqrt(image[:,:,0])
image[:,:,1] = np.sqrt(image[:,:,1])
image[:,:,2] = np.sqrt(image[:,:,2])
return image
ImageReader.add_scaling('sqrt',sqrt)
"""
cls.norm_dict[name] = func


class DC2ImageReader(ImageReader):
"""An ImageReader for DC2 image files."""

def __init__(self, *args, **kwargs):
# Pass arguments to the parent function.
super().__init__(*args, **kwargs)

def _read_image(self, filename):
"""Read the image.
Parameters
----------
filename : str
The filename indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
file = filename.split("/")[-1].split(".")[0]
base = os.path.dirname(filename)
fn = os.path.join(base, file) + ".npy"
image = np.load(fn)
image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32)
return image


class HSCImageReader(ImageReader):
"""An ImageReader for HSC image files."""

def __init__(self, *args, **kwargs):
# Pass arguments to the parent function.
super().__init__(*args, **kwargs)

def _read_image(self, filenames):
"""Read the image.
Parameters
----------
filenames : list
A length 3 list of filenames for the I, R, and G images.
Returns
-------
im : numpy array
The image.
"""
if len(filenames) != 3:
raise ValueError("Incorrect number of filenames passed.")

g = fits.getdata(os.path.join(filenames[0]), memmap=False)
length, width = g.shape
image = np.empty([length, width, 3])
r = fits.getdata(os.path.join(filenames[1]), memmap=False)
i = fits.getdata(os.path.join(filenames[2]), memmap=False)

image[:, :, 0] = i
image[:, :, 1] = r
image[:, :, 2] = g
return image

4 changes: 2 additions & 2 deletions src/deepdisc/inference/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_predictions(dataset_dict, imreader, key_mapper, predictor):
----------
dataset_dict : dictionary
The dictionary metadata for a single image
imreader: ImageReader class
The ImageReader used to load in images
imreader: ImageReader object
An object derived from ImageReader base class to read the images.
key_mapper: function
The key_mapper should take a dataset_dict as input and return the key used by imreader
predictor: AstroPredictor
Expand Down
Loading
Loading