Skip to content

Commit

Permalink
Merge pull request #117 from lincc-frameworks/roman_data
Browse files Browse the repository at this point in the history
Added RomanImageReader and create new swin config file for Roman data
  • Loading branch information
grantmerz authored Sep 17, 2024
2 parents b11c476 + e505a33 commit ac0176b
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 3 deletions.
104 changes: 104 additions & 0 deletions configs/solo/solo_swin_Roman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
""" This is a demo "solo config" file for use in solo_test_run_transformers.py.
This uses template configs cascade_mask_rcnn_swin_b_in21k_50ep and yaml_style_defaults."""

from omegaconf import OmegaConf
import numpy as np
# ---------------------------------------------------------------------------- #
# Local variables and metadata
# ---------------------------------------------------------------------------- #
epoch=64 #(NUM OF IMAGES / BATCH SIZE 280/4)
bs=4 # 1 or 2 times the number of gpus
metadata = OmegaConf.create()
metadata.classes = ["galaxy", "star", "no_match"]

numclasses = len(metadata.classes)

# ---------------------------------------------------------------------------- #
# Standard config (this has always been the LazyConfig/.py-style config)
# ---------------------------------------------------------------------------- #
# Get values from templates
from ..COCO.cascade_mask_rcnn_swin_b_in21k_50ep import dataloader, model, train, lr_multiplier, optimizer
import deepdisc.model.loaders as loaders
from deepdisc.data_format.augment_image import dc2_train_augs, dc2_train_augs_full
from deepdisc.data_format.image_readers import RomanImageReader

# Overrides
dataloader.augs = dc2_train_augs
dataloader.train.total_batch_size = bs

model.proposal_generator.anchor_generator.sizes = [[8], [16], [32], [64], [128]]
model.roi_heads.num_classes = numclasses
model.roi_heads.batch_size_per_image = 512

model.roi_heads.num_classes = numclasses
model.roi_heads.batch_size_per_image = 512
model.backbone.bottom_up.in_chans = 4
model.pixel_mean = [
0.72566669,
1.08332919,
0.89876418,
1.081059
]
model.pixel_std = [
21.29308954,
28.89818123,
27.25925706,
28.90273031
]

#model.roi_heads.num_components = 3
#model.roi_heads.zloss_factor = 1
#model.roi_heads.zbins = np.linspace(0,5,200)
#model.roi_heads.weights = np.load('/home/g4merz/rail_deepdisc/configs/solo/zweights.npy')
#model.roi_heads._target_ = RedshiftPDFCasROIHeads

model.proposal_generator.nms_thresh = 0.3

for box_predictor in model.roi_heads.box_predictors:
box_predictor.test_topk_per_image = 2000
box_predictor.test_score_thresh = 0.5
box_predictor.test_nms_thresh = 0.3

train.init_checkpoint = "./detectron2/projects/ViTDet/model_final_246a82.pkl"

optimizer.lr = 0.001
dataloader.test.mapper = loaders.DictMapper
dataloader.train.mapper = loaders.DictMapper
reader = RomanImageReader()
dataloader.imagereader = reader
dataloader.epoch=epoch
# ---------------------------------------------------------------------------- #
# Yaml-style config (was formerly saved as a .yaml file, loaded to cfg_loader)
# ---------------------------------------------------------------------------- #
# Get values from template
from .yacs_style_defaults import MISC, DATALOADER, DATASETS, GLOBAL, INPUT, MODEL, SOLVER, TEST

# Overrides
SOLVER.IMS_PER_BATCH = bs

DATASETS.TRAIN = "astro_train"
DATASETS.TEST = "astro_val"

SOLVER.BASE_LR = 0.001
SOLVER.CLIP_GRADIENTS.ENABLED = True
# Type of gradient clipping, currently 2 values are supported:
# - "value": the absolute values of elements of each gradients are clipped
# - "norm": the norm of the gradient for each parameter is clipped thus
# affecting all elements in the parameter
SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"
# Maximum absolute value used for clipping gradients
# Floating point number p for L-p norm to be used with the "norm"
# gradient clipping type; for L-inf, please specify .inf
SOLVER.CLIP_GRADIENTS.NORM_TYPE = 5.0


e1 = epoch * 15
e2 = epoch * 25
e3 = epoch * 30
efinal = epoch * 50

SOLVER.STEPS = [e1,e2,e3] # do not decay learning rate for retraining
SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
SOLVER.WARMUP_ITERS = 0
SOLVER.MAX_ITER = efinal # for DefaultTrainer
25 changes: 24 additions & 1 deletion src/deepdisc/data_format/image_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def _read_image(self, filename):
image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32)
return image


class HSCImageReader(ImageReader):
"""An ImageReader for HSC image files."""

Expand Down Expand Up @@ -210,3 +209,27 @@ def _read_image(self, filenames):
image[:, :, 1] = r
image[:, :, 2] = g
return image

class RomanImageReader(ImageReader):
"""An ImageReader for Roman image files."""

def __init__(self, *args, **kwargs):
# Pass arguments to the parent function.
super().__init__(*args, **kwargs)

def _read_image(self, filename):
"""Read the image.
Parameters
----------
filename : str
The filename indicating the image to read.
Returns
-------
im : numpy array
The image.
"""
image = np.load(filename) # (4, 512, 512)
image = np.transpose(image, axes=(1, 2, 0)).astype(np.float32) # (512, 512, 4)
return image
1 change: 0 additions & 1 deletion src/deepdisc/inference/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def get_predictions(dataset_dict, imreader, key_mapper, predictor):
outputs: list(Intances)
The list of detected object Instances
"""

key = key_mapper(dataset_dict)
img = imreader(key)
outputs = predictor(img)
Expand Down
2 changes: 1 addition & 1 deletion src/deepdisc/training/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, model, data_loader, optimizer, cfg):
self.checkpointer = checkpointer.DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
cfg.OUTPUT_DIR, # save checkpoint with loss_list
)
# load weights
self.checkpointer.load(cfg.train.init_checkpoint)
Expand Down
7 changes: 7 additions & 0 deletions tests/deepdisc/data_format/test_image_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def test_read_dc2_data(dc2_single_test_file):
assert img.shape[1] == 525
assert img.shape[2] == 6

def test_read_roman_data(roman_single_test_file):
"""Test that we can read the test Roman data."""
ir = RomanImageReader(norm="raw")
img = ir(dc2_single_test_file)
assert img.shape[0] == 525
assert img.shape[1] == 525
assert img.shape[2] == 4

def test_lupton_base_case(dc2_single_test_file):
"""Test that we can call lupton scaling without crashing"""
Expand Down

0 comments on commit ac0176b

Please sign in to comment.