diff --git a/validate.py b/validate.py
index cb71a0624e..86327731f3 100755
--- a/validate.py
+++ b/validate.py
@@ -314,11 +314,15 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
-        if args.channels_last:
-            input = input.contiguous(memory_format=torch.channels_last)
-        with amp_autocast():
-            model(input)
+        inputs = [torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)]
+        last_batch_size = len(dataset) % args.batch_size
+        if last_batch_size:
+            inputs.append(torch.randn((last_batch_size,) + tuple(data_config['input_size'])).to(device))
+        for inp in inputs:
+            if args.channels_last:
+                inp = inp.contiguous(memory_format=torch.channels_last)
+            with amp_autocast():
+                model(inp)
 
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):