Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
YazhouZhu19 authored Mar 14, 2024
1 parent 69b7835 commit 93d6712
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from dataloaders.datasets import TrainDataset as TrainDataset
from utils import *
from config import ex

from losses import *


Expand Down Expand Up @@ -94,7 +93,7 @@ def main(_run, _config, _log):

i_iter = 0
_log.info(f'Start training...')

eta = 1.
for sub_epoch in range(n_sub_epochs):
_log.info(f'This is epoch "{sub_epoch}" of "{n_sub_epochs}" epochs.')
Expand All @@ -110,16 +109,16 @@ def main(_run, _config, _log):
query_labels = torch.cat([query_label.long().cuda() for query_label in sample['query_labels']], dim=0)

# Compute outputs and losses.
query_pred = model(support_images, support_fg_mask, query_images, train=True)
query_pred, periphery_loss, align_loss, mse_loss, qry_loss = model(support_images, support_fg_mask,
query_images, query_labels, train=True)

query_loss = criterion(torch.log(torch.clamp(query_pred, torch.finfo(torch.float32).eps,
1 - torch.finfo(torch.float32).eps)), query_labels)

bd_loss = criterion_bd(query_pred, query_labels)

dice_loss = criterion_dice(query_pred, query_labels)
# bd_loss = criterion_bd(query_pred, query_labels)
# dice_loss = criterion_dice(query_pred, query_labels)

loss = query_loss + (1 - eta) * bd_loss + eta * dice_loss
loss = query_loss + 0.1 * periphery_loss + align_loss + 0.1 * mse_loss + qry_loss

# Compute gradient and do SGD step.
for param in model.parameters():
Expand Down Expand Up @@ -156,9 +155,8 @@ def main(_run, _config, _log):
os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))

i_iter += 1

eta = eta - 0.01



_log.info('End of training.')
return 1

0 comments on commit 93d6712

Please sign in to comment.