Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configs follow up #66

Merged
merged 25 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8a17117
Rename defaults file
OliviaLynn Nov 17, 2023
321881a
Start mvitv, clean up imports
OliviaLynn Nov 17, 2023
a57c931
Update ref
OliviaLynn Nov 17, 2023
a8f5735
Add dir for solo run scripts
OliviaLynn Nov 17, 2023
9ab8c16
Add dc2 support
OliviaLynn Nov 27, 2023
d6068db
Undo switched names
OliviaLynn Nov 27, 2023
017bcaa
Add DC2_redshift run
OliviaLynn Nov 27, 2023
9e670bc
Remove added DC2 comments
OliviaLynn Nov 27, 2023
cb6c9db
Fix flipped dtype
OliviaLynn Nov 27, 2023
7c8ab45
test_eval_model first pass; start removing cfg_node from predictors &…
OliviaLynn Nov 28, 2023
25cdecf
Predictors mostly updated
OliviaLynn Dec 1, 2023
b178d8f
Loaders docstrings
OliviaLynn Dec 1, 2023
474afa5
Cover trainers
OliviaLynn Dec 1, 2023
9a00b54
More predictors; clean ups
OliviaLynn Dec 1, 2023
1f3bdaf
Predictors branch 2
OliviaLynn Dec 1, 2023
5ab5681
Spacing
OliviaLynn Dec 4, 2023
f579988
Add .map_data to solo run file for DC2 (as in #72)
OliviaLynn Dec 5, 2023
b3c66d4
Improve consistency in language used in configs
OliviaLynn Dec 5, 2023
c4a12e3
Apply black to configs
OliviaLynn Dec 6, 2023
336b087
Docstrings, remove old comments
OliviaLynn Dec 6, 2023
6ff26f1
Update solo_test_run_transformers_DC2_redshift.py
OliviaLynn Dec 6, 2023
6cfaca1
Update solo_run_scripts/solo_test_eval_model.py
OliviaLynn Dec 7, 2023
c083438
Update solo_run_scripts/solo_test_eval_model.py
OliviaLynn Dec 7, 2023
7665664
Update solo_run_scripts/solo_test_eval_model.py
OliviaLynn Dec 7, 2023
7948791
PR comments
OliviaLynn Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions solo_run_scripts/solo_test_eval_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
This code will read in a trained model and output the classes for predicted objects matched to the ground truth

"""
import logging
import os
import time

import numpy as np
import deepdisc.astrodet.astrodet as toolkit

from deepdisc.data_format.file_io import get_data_from_json
from deepdisc.data_format.image_readers import HSCImageReader
from deepdisc.inference.match_objects import get_matched_object_classes
from deepdisc.inference.predictors import return_predictor_transformer
from deepdisc.utils.parse_arguments import dtype_from_args, make_inference_arg_parser

from detectron2 import model_zoo
#from detectron2.config import get_cfg, LazyConfig
OliviaLynn marked this conversation as resolved.
Show resolved Hide resolved
from detectron2.config import LazyConfig
from detectron2.data import MetadataCatalog
from detectron2.utils.logger import setup_logger

from pathlib import Path

setup_logger()
logger = logging.getLogger(__name__)

# Inference should use the config with parameters that are used in training
# cfg now already contains everything we've set previously. We changed it a little bit for inference:

def return_predictor(
    cfgfile, run_name, nc=1, output_dir="/home/shared/hsc/HSC/HSC_DR3/models/noclass/", roi_thresh=0.5
):
    """Return a trained model (predictor) and its config.

    Used for models that have yacs config files.

    Parameters
    ----------
    cfgfile : str
        A path to a model config file, provided by the detectron2 repo
    run_name : str
        Prefix used for the name of the saved model
    nc : int
        Number of classes used in the model
    output_dir : str
        The directory to save metric outputs
    roi_thresh : float
        Hyperparameter that functions as a detection sensitivity level

    Returns
    -------
    predictor : toolkit.AstroPredictor
        The predictor wrapping the trained model weights
    cfg
        The config the predictor was built from
    """
    # NOTE(review): this loads a LazyConfig but then sets yacs-style
    # UPPERCASE keys on it — confirm AstroPredictor expects this layout.
    cfg = LazyConfig.load(cfgfile)

    cfg.MODEL.ROI_HEADS.NUM_CLASSES = nc
    cfg.OUTPUT_DIR = output_dir
    # Path to the trained weights we want to evaluate.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, run_name)
    # Custom testing threshold: detections scoring below this are dropped.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = roi_thresh

    predictor = toolkit.AstroPredictor(cfg)

    return predictor, cfg


# Inference should use the config with parameters that are used in training
# cfg now already contains everything we've set previously. We changed it a little bit for inference:
OliviaLynn marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == "__main__":
    # --------- Handle args
    args = make_inference_arg_parser().parse_args()
    roi_thresh = args.roi_thresh
    run_name = args.run_name
    testfile = args.testfile
    savedir = args.savedir
    Path(savedir).mkdir(parents=True, exist_ok=True)
    output_dir = args.output_dir
    dtype = dtype_from_args(args.datatype)

    # --------- Load data
    dataset_names = ["test"]
    datadir = "/home/shared/hsc/HSC/HSC_DR3/data/"
    t0 = time.time()
    dataset_dicts = {}
    for i, d in enumerate(dataset_names):
        dataset_dicts[d] = get_data_from_json(testfile)
    print("Took ", time.time() - t0, "seconds to load samples")

    # Local vars/metadata
    # The backbone name is the first token of the run name,
    # e.g. "Swin_..." or "MViTv2_...".
    bb = args.run_name.split("_")[0]

    # --------- Start config stuff
    cfgfile = (
        "./tests/deepdisc/test_data/configs/"
        "solo/solo_cascade_mask_rcnn_swin_b_in21k_50ep_test_eval.py"
    )
    cfg = LazyConfig.load(cfgfile)

    # --------- Setting a bunch of config stuff
    cfg.model.roi_heads.num_classes = args.nc

    # Configure every cascade box predictor in a single pass.
    # (The original code looped twice, setting test_score_thresh
    # redundantly in the first loop.)
    for box_predictor in cfg.model.roi_heads.box_predictors:
        box_predictor.test_topk_per_image = 1000
        box_predictor.test_score_thresh = roi_thresh

    cfg.train.init_checkpoint = os.path.join(output_dir, run_name)

    # --------- Now we case predictor on model type (the second case has way different config vals it appears)
    cfg.OUTPUT_DIR = output_dir
    if bb in ["Swin", "MViTv2"]:
        predictor = return_predictor_transformer(cfg)
    else:
        # Non-transformer backbones go through the yacs-style eval config.
        cfgfile = "./tests/deepdisc/test_data/configs/solo/solo_test_eval_model_option.py"
        predictor, cfg = return_predictor(cfgfile, run_name, output_dir=output_dir, nc=2, roi_thresh=roi_thresh)

    # --------- Key mapper: pull the per-band file paths out of one record
    def hsc_key_mapper(dataset_dict):
        """Return the G/R/I image filenames for one HSC dataset record."""
        filenames = [
            dataset_dict["filename_G"],
            dataset_dict["filename_R"],
            dataset_dict["filename_I"],
        ]
        return filenames

    IR = HSCImageReader(norm=args.norm)

    # --------- Do the thing
    t0 = time.time()
    print("Matching objects")
    true_classes, pred_classes = get_matched_object_classes(dataset_dicts["test"], IR, hsc_key_mapper, predictor)
    classes = np.array([true_classes, pred_classes])

    savename = f"{bb}_test_matched_classes.npy"
    np.save(os.path.join(args.savedir, savename), classes)

    print("Took ", time.time() - t0, " seconds")
    print(classes)
    t0 = time.time()
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,19 @@ def main(train_head, args):
scheme = args.scheme
alphas = args.alphas
modname = args.modname
dtype = dtype_from_args(args.dtype)

# Get file locations
trainfile = dirpath + "single_test.json"
testfile = dirpath + "single_test.json"
if modname == "swin":
cfgfile = "./tests/deepdisc/test_data/configs/solo/solo_cascade_mask_rcnn_swin_b_in21k_50ep.py"
# initwfile = "/home/shared/hsc/detectron2/projects/ViTDet/model_final_246a82.pkl"
elif modname == "mvitv2":
cfgfile = "/home/shared/hsc/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py"
# initwfile = "/home/shared/hsc/detectron2/projects/ViTDet/model_final_8c3da3.pkl"
elif modname == "vitdet":
cfgfile = "/home/shared/hsc/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py"
# initwfile = '/home/g4merz/deblend/detectron2/projects/ViTDet/model_final_435fa9.pkl'
# initwfile = "/home/shared/hsc/detectron2/projects/ViTDet/model_final_61ccd1.pkl"
dtype = dtype_from_args(args.dtype)
trainfile = dirpath + "single_test.json"
testfile = dirpath + "single_test.json"

cfgfile = "./tests/deepdisc/test_data/configs/solo/solo_cascade_mask_rcnn_mvitv2_b_in21k_100ep.py"
# Vitdet not currently available (cuda issues) so we're tabling it for now
#elif modname == "vitdet":
OliviaLynn marked this conversation as resolved.
Show resolved Hide resolved
# cfgfile = "/home/shared/hsc/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py"

# Load the config
cfg = LazyConfig.load(cfgfile)

Expand All @@ -91,6 +90,7 @@ def main(train_head, args):
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Iterations for 15, 25, 35, 50 epochs
# TODOLIV could this stuff be moved to a config too?
epoch = int(args.tl / cfg.dataloader.train.total_batch_size)
e1 = 20
e2 = epoch * 10
Expand Down Expand Up @@ -135,7 +135,7 @@ def hsc_key_mapper(dataset_dict):
schedulerHook = return_schedulerhook(optimizer)
hookList = [lossHook, schedulerHook, saveHook]

trainer = return_lazy_trainer(model, loader, optimizer, cfg, cfg, hookList)
trainer = return_lazy_trainer(model, loader, optimizer, cfg, hookList)

trainer.set_period(5)
trainer.train(0, 20)
Expand Down
170 changes: 170 additions & 0 deletions solo_run_scripts/solo_test_run_transformers_DC2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
""" Training script for LazyConfig models.

This uses the new "solo config" in which the previous yaml-style config
(a Detectron CfgNode type called cfg_loader) is now bundled into the
LazyConfig type cfg.
"""

try:
# ignore ShapelyDeprecationWarning from fvcore
import warnings
from shapely.errors import ShapelyDeprecationWarning
warnings.filterwarnings("ignore", category=sShapelyDeprecationWarning)
except:
pass
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Some basic setup:
# Setup detectron2 logger
from detectron2.utils.logger import setup_logger
setup_logger()

import gc
import os
import time

import detectron2.utils.comm as comm

# import some common libraries
import numpy as np
import torch

# import some common detectron2 utilities
from detectron2.config import LazyConfig, get_cfg
from detectron2.engine import launch

from deepdisc.data_format.augment_image import train_augs
from deepdisc.data_format.image_readers import DC2ImageReader
from deepdisc.data_format.register_data import register_data_set
from deepdisc.model.loaders import DictMapper, return_test_loader, return_train_loader
from deepdisc.model.models import return_lazy_model
from deepdisc.training.trainers import (
return_evallosshook,
return_lazy_trainer,
return_optimizer,
return_savehook,
return_schedulerhook,
)
from deepdisc.utils.parse_arguments import make_training_arg_parser


def main(train_head, args):
    """Train the head layers of a DC2 model configured via a solo LazyConfig.

    Parameters
    ----------
    train_head : bool
        When True, train from scratch (no initial checkpoint) for a short
        head-training run.
    args : argparse.Namespace
        Parsed command-line arguments from make_training_arg_parser().
    """
    # Hack if you get SSL certificate error
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context

    # Handle args
    output_dir = args.output_dir
    output_name = args.run_name
    dirpath = args.data_dir  # Path to dataset
    scheme = args.scheme  # NOTE(review): unused in this function — confirm still needed
    alphas = args.alphas  # NOTE(review): unused in this function — confirm still needed
    modname = args.modname
    datatype = args.dtype
    if datatype == 8:
        dtype = np.uint8
    elif datatype == 16:
        # NOTE(review): 16 maps to *signed* int16 here, while the sibling HSC
        # script derives dtype via dtype_from_args — confirm int16 vs uint16.
        # `dtype` is not used further in this function.
        dtype = np.int16

    # Get file locations
    trainfile = dirpath + "single_test.json"
    testfile = dirpath + "single_test.json"
    if modname == "swin":
        cfgfile = "./tests/deepdisc/test_data/configs/solo/solo_cascade_mask_rcnn_swin_b_in21k_50ep_DC2.py"
    elif modname == "mvitv2":
        cfgfile = "./tests/deepdisc/test_data/configs/solo/solo_cascade_mask_rcnn_mvitv2_b_in21k_100ep_DC2.py"
    # Vitdet not currently available (cuda issues) so we're tabling it for now
    #elif modname == "vitdet":
    #    cfgfile = "/home/shared/hsc/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py"

    # Load the config
    cfg = LazyConfig.load(cfgfile)

    # Register the data sets (train and validation both point at the same
    # single_test.json fixture in this solo-run script)
    astrotrain_metadata = register_data_set(
        cfg.DATASETS.TRAIN, trainfile, thing_classes=cfg.metadata.classes
    )
    astroval_metadata = register_data_set(
        cfg.DATASETS.TEST, testfile, thing_classes=cfg.metadata.classes
    )

    # Set the output directory
    cfg.OUTPUT_DIR = output_dir
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # Iterations for 15, 25, 35, 50 epochs
    # TODOLIV could this stuff be moved to a config too?
    # NOTE(review): only e1 is used below; e2/e3/efinal appear to be kept
    # for parity with the longer training scripts.
    epoch = int(args.tl / cfg.dataloader.train.total_batch_size)
    e1 = 20
    e2 = epoch * 10
    e3 = epoch * 20
    efinal = epoch * 35

    # Evaluate validation loss every `val_per` iterations.
    val_per = 5

    if train_head:
        cfg.train.init_checkpoint = None  # or initwfile, the path to your model

        model = return_lazy_model(cfg)

        cfg.optimizer.params.model = model
        cfg.optimizer.lr = 0.001

        cfg.SOLVER.STEPS = []  # do not decay learning rate for retraining
        cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
        cfg.SOLVER.WARMUP_ITERS = 0
        cfg.SOLVER.MAX_ITER = e1  # for DefaultTrainer

        # optimizer = instantiate(cfg.optimizer)
        optimizer = return_optimizer(cfg)

        def dc2_key_mapper(dataset_dict):
            # DC2 records carry a single image filename (no per-band files).
            filename = dataset_dict["filename"]
            return filename

        # Build train/test loaders; the test loader gets no augmentations.
        IR = DC2ImageReader(norm=args.norm)
        mapper = DictMapper(IR, dc2_key_mapper, train_augs).map_data
        loader = return_train_loader(cfg, mapper)
        test_mapper = DictMapper(IR, dc2_key_mapper).map_data
        test_loader = return_test_loader(cfg, test_mapper)

        # Hooks: checkpoint saving, eval-loss tracking, LR scheduling.
        saveHook = return_savehook(output_name)
        lossHook = return_evallosshook(val_per, model, test_loader)
        schedulerHook = return_schedulerhook(optimizer)
        hookList = [lossHook, schedulerHook, saveHook]

        trainer = return_lazy_trainer(model, loader, optimizer, cfg, hookList)

        trainer.set_period(5)
        trainer.train(0, 20)
        # Only the main process writes the loss curves to disk.
        if comm.is_main_process():
            np.save(output_dir + output_name + "_losses", trainer.lossList)
            np.save(output_dir + output_name + "_val_losses", trainer.vallossList)
        return


if __name__ == "__main__":
    # Parse training arguments and echo them for reproducibility.
    args = make_training_arg_parser().parse_args()
    print("Command Line Args:", args)

    print("Training head layers")
    train_head = True
    t0 = time.time()
    # Run main() through detectron2's (possibly multi-GPU/multi-machine)
    # distributed launcher.
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(
            train_head,
            args,
        ),
    )

    # Release cached GPU memory and force a GC pass after training.
    torch.cuda.empty_cache()
    gc.collect()

    print(f"Took {time.time()-t0} seconds")

Loading