Merge pull request #123 from grantmerz/val_loss
add new hooks
grantmerz authored Jan 10, 2025
2 parents 4b46762 + 143fb78 commit 6dacb41
Showing 2 changed files with 171 additions and 10 deletions.
55 changes: 48 additions & 7 deletions src/deepdisc/astrodet/detectron.py
@@ -128,6 +128,37 @@ def after_train(self):
self.trainer.checkpointer.save(self.output_name) # Note: Set the name of the output model here




class NewSaveHook(HookBase):

"""
This Hook saves the model during training
"""

output_name = "model_temp"


def __init__(self, save_period):
self._period = save_period

def set_output_name(self, name):
self.output_name = name

#def after_train(self):
# self.trainer.checkpointer.save(self.output_name) # Note: Set the name of the output model here

def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
# Guard on self._period > 0 so a non-positive period cannot trigger a division by zero below.
if self._period > 0 and (is_final or next_iter % self._period == 0):
print("saving", self.output_name)
self.trainer.checkpointer.save(f'{self.output_name}_{next_iter//self._period}')
if is_final:
self.trainer.checkpointer.save(self.output_name)
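
For reference, here is a minimal standalone sketch of the checkpoint names this hook produces, using a hypothetical save_period of 500 and max_iter of 2000 (the real hook is driven by the trainer; this only replays the naming logic):

    # Illustration only: replay of NewSaveHook's checkpoint naming (values are hypothetical).
    save_period = 500
    max_iter = 2000
    output_name = "model_temp"
    for it in range(1, max_iter + 1):
        if save_period > 0 and it % save_period == 0:
            print(f"{output_name}_{it // save_period}")  # model_temp_1 ... model_temp_4
        if it == max_iter:
            print(output_name)  # final checkpoint saved under the base name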


#
class LossEvalHook(HookBase):

@@ -157,6 +188,7 @@ def _do_loss_eval(self):
start_time = time.perf_counter()
total_compute_time = 0
losses = []
losses_dicts = []
with torch.no_grad():
for idx, inputs in enumerate(self._data_loader):
if idx == num_warmup:
@@ -178,29 +210,38 @@ def _do_loss_eval(self):
),
n=5,
)
loss_batch = self._get_loss(inputs)
loss_batch, metrics_dict = self._get_loss(inputs)
losses.append(loss_batch)
losses_dicts.append(metrics_dict)
mean_loss = np.mean(losses)
averaged_losses_dict = {}
for d in losses_dicts:
for key, value in d.items():
if key not in averaged_losses_dict:
averaged_losses_dict[key] = [0, 0] # [sum, count]
averaged_losses_dict[key][0] += value
averaged_losses_dict[key][1] += 1
averaged_losses_dict = {key: total / count for key, (total, count) in averaged_losses_dict.items()}
# print('validation_loss', mean_loss)
self.trainer.storage.put_scalar("validation_loss", mean_loss)
self.trainer.add_val_loss(mean_loss)
self.trainer.valloss = mean_loss

self.trainer.add_val_loss_dict(averaged_losses_dict)
self.trainer.vallossdict = averaged_losses_dict

comm.synchronize()
return losses

def _get_loss(self, data):
# How loss is calculated on train_loop
try:
metrics_dict = self._model(data)
except:
print("Check the size of the images in the validation set")
return 0
metrics_dict = self._model(data)
metrics_dict = {
k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
for k, v in metrics_dict.items()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
return total_losses_reduced
return total_losses_reduced, metrics_dict
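
To make the new per-key averaging in _do_loss_eval concrete, here is a small standalone sketch using made-up per-batch loss dicts of the kind _get_loss now returns (the key names and values are examples only):

    # Two hypothetical per-batch validation loss dicts.
    losses_dicts = [
        {"loss_cls": 0.50, "loss_box_reg": 0.25},
        {"loss_cls": 0.75, "loss_box_reg": 0.25},
    ]
    # Same sum/count accumulation used in _do_loss_eval above.
    averaged_losses_dict = {}
    for d in losses_dicts:
        for key, value in d.items():
            if key not in averaged_losses_dict:
                averaged_losses_dict[key] = [0, 0]  # [sum, count]
            averaged_losses_dict[key][0] += value
            averaged_losses_dict[key][1] += 1
    averaged_losses_dict = {key: total / count for key, (total, count) in averaged_losses_dict.items()}
    print(averaged_losses_dict)  # {'loss_cls': 0.625, 'loss_box_reg': 0.25}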

def after_step(self):
next_iter = self.trainer.iter + 1
126 changes: 123 additions & 3 deletions src/deepdisc/training/trainers.py
@@ -20,21 +20,24 @@ def __init__(self, model, data_loader, optimizer, cfg):
self.checkpointer = checkpointer.DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR, # save checkpoint with loss_list
cfg.OUTPUT_DIR,
)
# load weights
self.checkpointer.load(cfg.train.init_checkpoint)

# record loss over iteration
self.lossList = []
self.lossdict_epochs = {}
self.vallossList = []
self.vallossdict_epochs = {}

self.period = 20
self.iterCount = 0

self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# self.scheduler = instantiate(cfg.lr_multiplier)
self.valloss = 0
self.vallossdict = {}

# Note: print out loss over p iterations
def set_period(self, p):
@@ -53,6 +56,14 @@ def run_step(self):
loss_dict = self.model(data)
loss_time = time.perf_counter() - start


ld = {
k: v.detach().cpu().item() if (isinstance(v, torch.Tensor) and v.numel() == 1)
else (v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else float(v))
for k, v in loss_dict.items()
}

self.lossdict_epochs[str(self.iterCount)] = ld

# print('Loss dict',loss_dict)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
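
As an aside, each per-iteration dict recorded above lands in trainer.lossdict_epochs, keyed by the iteration count as a string. A post-training inspection might look roughly like the sketch below (key names and values are purely illustrative):

    # Hypothetical contents of lossdict_epochs after two recorded iterations.
    lossdict_epochs = {
        "1": {"loss_cls": 0.9, "loss_box_reg": 0.5},
        "2": {"loss_cls": 0.8, "loss_box_reg": 0.45},
    }
    # Pull out one loss component over iterations, e.g. for plotting a training curve.
    iters = sorted(lossdict_epochs, key=int)
    loss_cls_curve = [lossdict_epochs[i]["loss_cls"] for i in iters]
    print(iters, loss_cls_curve)  # ['1', '2'] [0.9, 0.8]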
@@ -105,6 +116,114 @@ def add_val_loss(self, val_loss):

self.vallossList.append(val_loss)

def add_val_loss_dict(self, val_loss_dict):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""

self.vallossdict_epochs[str(self.iterCount)] = val_loss_dict


class LazyAstroEvaluator(SimpleTrainer):
def __init__(self, model, data_loader, optimizer, cfg):
super().__init__(model, data_loader, optimizer)

# Borrowed from DefaultTrainer constructor
# see https://detectron2.readthedocs.io/en/latest/_modules/detectron2/engine/defaults.html#DefaultTrainer
self.checkpointer = checkpointer.DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
)
# load weights
self.checkpointer.load(cfg.train.init_checkpoint)

# record loss over iteration
self.lossList = []
self.lossdict_epochs = {}
self.vallossList = []
self.vallossdict_epochs = {}

self.period = 20
self.iterCount = 0

self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# self.scheduler = instantiate(cfg.lr_multiplier)
self.valloss = 0
self.vallossdict = {}

# Note: print out loss over p iterations
def set_period(self, p):
self.period = p

# run_step is overridden here: instead of taking a training step, it loops over the
# data loader under torch.no_grad() and collects the per-batch loss dicts.
# see https://detectron2.readthedocs.io/en/latest/_modules/detectron2/engine/train_loop.html#SimpleTrainer
def run_step(self):
# Copying inference_on_dataset from evaluator.py
#total = len(self.data_loader)
#num_warmup = min(5, total - 1)

start_time = time.perf_counter()
total_compute_time = 0
losses = []
losses_dicts = []
with torch.no_grad():
for idx, inputs in enumerate(self.data_loader):
'''
if idx == num_warmup:
start_time = time.perf_counter()
total_compute_time = 0
start_compute_time = time.perf_counter()
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 5:
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
idx + 1, total, seconds_per_img, str(eta)
),
n=5,
)
'''
metrics_dict = self.model(inputs)
losses_dicts.append(metrics_dict)

#losses.append(loss_batch)
#losses_dicts.append(metrics_dict)
#mean_loss = np.mean(losses)
self.losses_dicts = losses_dicts
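
A rough sketch of how this evaluator might be driven, assuming a model, validation loader, optimizer, and LazyConfig cfg have already been built elsewhere (the names below are placeholders, not an API defined in this repo):

    # Hypothetical usage: collect one loss dict per validation batch.
    evaluator = LazyAstroEvaluator(model, val_loader, optimizer, cfg)
    evaluator.model.train()  # detectron2 models return loss dicts only in training mode
    evaluator.run_step()     # fills evaluator.losses_dicts, one dict per batch
    for i, d in enumerate(evaluator.losses_dicts):
        print(i, {k: float(v) for k, v in d.items()})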


@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)

def add_val_loss(self, val_loss):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""

self.vallossList.append(val_loss)

def add_val_loss_dict(self, val_loss_dict):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""

self.vallossdict_epochs[str(self.iterCount)] = val_loss_dict


def return_lazy_trainer(model, loader, optimizer, cfg, hooklist):
"""Return a trainer for models built on LazyConfigs
@@ -132,7 +251,7 @@ def return_lazy_trainer(model, loader, optimizer, cfg, hooklist):
return trainer


def return_savehook(output_name):
def return_savehook(output_name, save_period):
"""Returns a hook for saving the model
Parameters
@@ -144,7 +263,8 @@ def return_savehook(output_name):
-------
a SaveHook
"""
saveHook = detectron_addons.SaveHook()
#saveHook = detectron_addons.SaveHook()
saveHook = detectron_addons.NewSaveHook(save_period)
saveHook.set_output_name(output_name)
return saveHook
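
Putting the pieces together, a call site might now look roughly like this (the period values and the training bounds are placeholders; model, loader, optimizer, and cfg come from the usual LazyConfig setup):

    # Hypothetical wiring of the new save_period argument.
    saveHook = return_savehook("model_temp", save_period=500)  # numbered checkpoint every 500 iters
    trainer = return_lazy_trainer(model, loader, optimizer, cfg, [saveHook])
    trainer.set_period(20)   # record/print losses every 20 iterations
    trainer.train(0, 2000)   # start_iter, max_iter (placeholders)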

