Skip to content

Commit d02f71a

Browse files
author
Nathan Cassereau
committed
Better device_type
1 parent c507d3b commit d02f71a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ot/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,7 +1688,7 @@ def bitsize(self, type_as):
16881688
return torch.finfo(type_as.dtype).bits
16891689

16901690
def device_type(self, type_as):
1691-
return "CPU" if "cpu" in str(type_as.device) else "GPU"
1691+
return type_as.device.type.replace("cuda", "gpu").upper()
16921692

16931693
def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
16941694
results = dict()
@@ -2337,7 +2337,7 @@ def bitsize(self, type_as):
23372337
return type_as.dtype.size * 8
23382338

23392339
def device_type(self, type_as):
2340-
return "CPU" if "CPU" in type_as.device else "GPU"
2340+
return self.dtype_device(type_as)[1].split(":")[0]
23412341

23422342
def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
23432343
results = dict()

0 commit comments

Comments
 (0)