diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 9991e5ec87d8..5074fa95ae4b 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -72,11 +72,12 @@ def select_device(device='', batch_size=None):
 
     cuda = not cpu and torch.cuda.is_available()
     if cuda:
-        n = torch.cuda.device_count()
-        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
+        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+        n = len(devices)  # device count
+        if n > 1 and batch_size:  # check batch_size is divisible by device_count
             assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
         space = ' ' * len(s)
-        for i, d in enumerate(device.split(',') if device else range(n)):
+        for i, d in enumerate(devices):
             p = torch.cuda.get_device_properties(i)
             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
     else: