Skip to content

Commit

Permalink
Adds Detr sanity fns (#51)
Browse files Browse the repository at this point in the history
* Adds Detr sanity fns

* bug fixes
  • Loading branch information
oke-aditya authored Nov 23, 2020
1 parent 0d1ee43 commit 7f61612
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 14 deletions.
149 changes: 143 additions & 6 deletions quickvision/models/detection/detr/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,150 @@ def fit(model: nn.Module, epochs: int, train_loader,
return history


def train_sanity_fit(model: nn.Module, train_loader, criterion, device: str,
                     num_batches: int = None, log_interval: int = 100, fp16: bool = False,):
    """
    Performs a sanity fit over the train loader.
    Use this to dummy-check your train_step function. It does not calculate
    metrics or timing, nor does it do checkpointing.
    It iterates over the train_loader for the given batches.
    Note: - It does not call loss.backward(); only the forward pass and the
            loss computation are exercised.
    Args:
        model : A PyTorch DETR Model.
        train_loader : Train loader.
        criterion : Loss function to be optimized.
        device : "cuda" or "cpu".
        num_batches : (optional) Integer to limit the sanity fit to a certain
                      number of batches. Useful if data is too big even for a sanity check.
        log_interval : (optional) Default 100. Log after the specified number of batch ids in every batch.
        fp16 : (optional) If True uses PyTorch native mixed precision for the forward pass.
    Returns:
        True once all requested batches complete without error.
    """

    model = model.to(device)
    criterion = criterion.to(device)
    train_sanity_start = time.time()
    model.train()
    criterion.train()

    last_idx = len(train_loader) - 1
    cnt = 0
    # No amp.GradScaler() here: a scaler is only needed for the backward/step
    # path, and a sanity fit never calls loss.backward().

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        last_batch = batch_idx == last_idx
        images = list(image.to(device) for image in inputs)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if fp16 is True:
            with amp.autocast():
                outputs = model(images)
                loss_dict = criterion(outputs, targets)
                weight_dict = criterion.weight_dict
                # Weighted sum over the losses the criterion declares weights for.
                loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        else:
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        cnt += 1
        if last_batch or batch_idx % log_interval == 0:
            print(f"Train sanity check passed for batch till {batch_idx} batches")

        if num_batches is not None and cnt >= num_batches:
            print(f"Done till {num_batches} train batches")
            print("All specified batches done")
            train_sanity_end = time.time()
            print(f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}")
            return True

    train_sanity_end = time.time()

    print("All specified batches done")
    print(f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}")

    return True


def val_sanity_fit(model: nn.Module, val_loader, criterion, device,
                   num_batches: int = None, log_interval: int = 100):

    """
    Performs a sanity fit over the valid loader.
    Use this to dummy-check your val_step function. It does not calculate
    metrics or timing, nor does it do checkpointing.
    It iterates over the val_loader for the given batches.
    Note: - It does not call loss.backward(); only the forward pass and the
            loss computation are exercised.
    Args:
        model : A PyTorch DETR Model.
        val_loader : Validation loader.
        criterion : Loss function to be optimized.
        device : "cuda" or "cpu".
        num_batches : (optional) Integer to limit the sanity fit to a certain
                      number of batches. Useful if data is too big even for a sanity check.
        log_interval : (optional) Default 100. Log after the specified number of batch ids in every batch.
    Returns:
        True once all requested batches complete without error.
    """

    model = model.to(device)
    criterion = criterion.to(device)
    val_sanity_start = time.time()
    model.eval()
    criterion.eval()

    last_idx = len(val_loader) - 1
    cnt = 0

    # NOTE(review): consider wrapping this loop in torch.no_grad() to skip
    # autograd bookkeeping during the eval-only pass — confirm against project style.
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        last_batch = batch_idx == last_idx
        images = list(image.to(device) for image in inputs)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(images)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # Weighted sum over the losses the criterion declares weights for.
        loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        cnt += 1
        if last_batch or batch_idx % log_interval == 0:
            print(f"Val sanity check passed for batch till {batch_idx} batches")

        if num_batches is not None and cnt >= num_batches:
            print(f"Done till {num_batches} validation batches")
            print("All specified batches done")
            val_sanity_end = time.time()
            print(f"Val sanity fit check passed in time {val_sanity_end-val_sanity_start}")
            return True

    val_sanity_end = time.time()

    print("All specified batches done")
    print(f"Val sanity fit check passed in time {val_sanity_end-val_sanity_start}")

    return True


def sanity_fit(model: nn.Module, train_loader, val_loader, criterion, device: str,
               num_batches: int = None, log_interval: int = 100, fp16: bool = False,):

    """
    Performs a sanity fit over the train loader and the valid loader.
    Use this to dummy-check your fit function. It does not calculate metrics
    or timing, nor does it do checkpointing.
    It iterates over both train_loader and val_loader for the given batches.
    Note: - It does not call loss.backward(); only forward passes and loss
            computation are exercised.
    Args:
        model : A PyTorch DETR Model.
        train_loader : Training loader.
        val_loader : Validation loader.
        criterion : Loss function to be optimized.
        device : "cuda" or "cpu".
        num_batches : (optional) Integer to limit the sanity fit to a certain
                      number of batches. Useful if data is too big even for a sanity check.
        log_interval : (optional) Default 100. Log after the specified number of batch ids in every batch.
        fp16 : (optional) If True uses PyTorch native mixed precision for the train forward pass.
    Returns:
        True when both the train and the validation sanity checks pass.
    """

    sanity_train = train_sanity_fit(model, train_loader, criterion, device, num_batches, log_interval, fp16)

    sanity_val = val_sanity_fit(model, val_loader, criterion, device, num_batches, log_interval)

    # Propagate the sub-check results instead of discarding them; both helpers
    # currently return True on success, so this stays backward compatible.
    return sanity_train and sanity_val
6 changes: 3 additions & 3 deletions quickvision/models/detection/faster_rcnn/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def train_sanity_fit(model: nn.Module, train_loader,

"""
Performs Sanity fit over train loader.
Use this to dummy check your fit function. It does not calculate metrics, timing, or does checkpointing.
It iterates over both train_loader and val_loader for given batches.
Use this to dummy check your train_step function. It does not calculate metrics, timing, or does checkpointing.
It iterates over both train_loader for given batches.
Note: - It does not to loss.backward().
Args:
model : A pytorch Faster RCNN Model.
Expand Down Expand Up @@ -292,7 +292,7 @@ def val_sanity_fit(model: nn.Module, val_loader,

"""
Performs Sanity fit over valid loader.
Use this to dummy check your fit function. It does not calculate metrics, timing, or does checkpointing.
Use this to dummy check your val_step function. It does not calculate metrics, timing, or does checkpointing.
It iterates over both train_loader and val_loader for given batches.
Note: - It does not to loss.backward().
Args:
Expand Down
60 changes: 55 additions & 5 deletions tests/test_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,71 @@ def test_fit_cuda(self):
self.assertTrue(exp_k in history.keys())

def test_train_sanity_fit(self):
    """Smoke-tests detr.train_sanity_fit on CPU across the supported backbones."""
    for bbone in some_supported_backbones:
        backbone = detr.create_detr_backbone(bbone, pretrained=None)
        self.assertTrue(isinstance(backbone, nn.Module))
        detr_model = detr.create_vision_detr(num_classes=3, num_queries=5, backbone=backbone)
        self.assertTrue(isinstance(detr_model, nn.Module))
        matcher = detr_loss.HungarianMatcher()
        weight_dict = {"loss_ce": 1, "loss_bbox": 1, "loss_giou": 1}
        losses = ["labels", "boxes", "cardinality"]
        criterion = detr_loss.SetCriterion(2, matcher, weight_dict, eos_coef=0.5, losses=losses)
        ret = detr.train_sanity_fit(detr_model, train_loader, criterion, "cpu", num_batches=4)
        self.assertTrue(ret)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_train_sanity_fit_cuda(self):
    """Smoke-tests detr.train_sanity_fit on CUDA with fp16 across the supported backbones."""
    for bbone in some_supported_backbones:
        backbone = detr.create_detr_backbone(bbone, pretrained=None)
        self.assertTrue(isinstance(backbone, nn.Module))
        detr_model = detr.create_vision_detr(num_classes=3, num_queries=5, backbone=backbone)
        self.assertTrue(isinstance(detr_model, nn.Module))
        matcher = detr_loss.HungarianMatcher()
        weight_dict = {"loss_ce": 1, "loss_bbox": 1, "loss_giou": 1}
        losses = ["labels", "boxes", "cardinality"]
        criterion = detr_loss.SetCriterion(2, matcher, weight_dict, eos_coef=0.5, losses=losses)
        ret = detr.train_sanity_fit(detr_model, train_loader, criterion, device="cuda", num_batches=4, fp16=True)
        self.assertTrue(ret)

def test_val_sanity_fit(self):
    """Smoke-tests detr.val_sanity_fit on CPU across the supported backbones."""
    for bbone in some_supported_backbones:
        backbone = detr.create_detr_backbone(bbone, pretrained=None)
        self.assertTrue(isinstance(backbone, nn.Module))
        detr_model = detr.create_vision_detr(num_classes=3, num_queries=5, backbone=backbone)
        self.assertTrue(isinstance(detr_model, nn.Module))
        matcher = detr_loss.HungarianMatcher()
        weight_dict = {"loss_ce": 1, "loss_bbox": 1, "loss_giou": 1}
        losses = ["labels", "boxes", "cardinality"]
        criterion = detr_loss.SetCriterion(2, matcher, weight_dict, eos_coef=0.5, losses=losses)
        ret = detr.val_sanity_fit(detr_model, val_loader, criterion, "cpu", num_batches=4)
        self.assertTrue(ret)

def test_sanity_fit(self):
    """Smoke-tests detr.sanity_fit (train + val passes) on CPU across the supported backbones."""
    for bbone in some_supported_backbones:
        backbone = detr.create_detr_backbone(bbone, pretrained=None)
        self.assertTrue(isinstance(backbone, nn.Module))
        detr_model = detr.create_vision_detr(num_classes=3, num_queries=5, backbone=backbone)
        self.assertTrue(isinstance(detr_model, nn.Module))
        matcher = detr_loss.HungarianMatcher()
        weight_dict = {"loss_ce": 1, "loss_bbox": 1, "loss_giou": 1}
        losses = ["labels", "boxes", "cardinality"]
        criterion = detr_loss.SetCriterion(2, matcher, weight_dict, eos_coef=0.5, losses=losses)
        ret = detr.sanity_fit(detr_model, train_loader, val_loader, criterion, "cpu", num_batches=4)
        self.assertTrue(ret)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_sanity_fit_cuda(self):
    """Smoke-tests detr.sanity_fit (train + val passes) on CUDA across the supported backbones."""
    for bbone in some_supported_backbones:
        backbone = detr.create_detr_backbone(bbone, pretrained=None)
        self.assertTrue(isinstance(backbone, nn.Module))
        detr_model = detr.create_vision_detr(num_classes=3, num_queries=5, backbone=backbone)
        self.assertTrue(isinstance(detr_model, nn.Module))
        matcher = detr_loss.HungarianMatcher()
        weight_dict = {"loss_ce": 1, "loss_bbox": 1, "loss_giou": 1}
        losses = ["labels", "boxes", "cardinality"]
        criterion = detr_loss.SetCriterion(2, matcher, weight_dict, eos_coef=0.5, losses=losses)
        ret = detr.sanity_fit(detr_model, train_loader, val_loader, criterion, "cuda", num_batches=4)
        self.assertTrue(ret)


class LightningTester(unittest.TestCase):
Expand Down

0 comments on commit 7f61612

Please sign in to comment.