import logging

import torch

import trtorch._C
from trtorch import _types


class Device(object):
    """
    Defines a device that can be used to specify target devices for engines

    Attributes:
        device_type (trtorch.DeviceType): Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

    device_type = None
    gpu_id = -1
    dla_core = -1
    allow_gpu_fallback = False

    def __init__(self, *args, **kwargs):
        """ __init__ Method for trtorch.Device

        Device accepts one of a few construction patterns

        Args:
            spec (str): String with device spec e.g. "dla:0" for DLA, core_id 0

        Keyword Arguments:
            gpu_id (int): ID of target GPU (will be overridden to the GPU managing the DLA if dla_core is specified). If specified, no positional arguments should be provided
            dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided.
            allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)

        Examples:
            - Device("gpu:1")
            - Device("cuda:1")
            - Device("dla:0", allow_gpu_fallback=True)
            - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
            - Device(dla_core=0, allow_gpu_fallback=True)
            - Device(gpu_id=1)
        """
        if len(args) == 1:
            if not isinstance(args[0], str):
                raise TypeError("When specifying Device through positional argument, argument must be str")
            else:
                (self.device_type, id) = Device._parse_device_str(args[0])
                if self.device_type == _types.DeviceType.GPU:
                    self.gpu_id = id
                else:
                    self.dla_core = id
                    self.gpu_id = 0
                    logging.warning("Setting GPU id to 0 for device because device 0 manages DLA on Xavier")

        elif len(args) == 0:
            if "gpu_id" in kwargs or "dla_core" in kwargs:
                if "dla_core" in kwargs:
                    self.device_type = _types.DeviceType.DLA
                    self.dla_core = kwargs["dla_core"]
                    if "gpu_id" in kwargs:
                        self.gpu_id = kwargs["gpu_id"]
                    else:
                        self.gpu_id = 0
                        logging.warning("Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
                else:
                    self.gpu_id = kwargs["gpu_id"]
                    self.device_type = _types.DeviceType.GPU
            else:
                raise ValueError("Either gpu_id or dla_core (or both) must be specified if no device spec string is provided")

        else:
            raise ValueError(
                "Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional argument"
                .format(len(args)))

        if "allow_gpu_fallback" in kwargs:
            if not isinstance(kwargs["allow_gpu_fallback"], bool):
                raise TypeError("allow_gpu_fallback must be a bool")
            self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]

    def __str__(self) -> str:
        if self.device_type == _types.DeviceType.GPU:
            return "Device(type={}, gpu_id={})".format(self.device_type, self.gpu_id)
        else:
            return "Device(type={}, gpu_id={}, dla_core={}, allow_gpu_fallback={})".format(
                self.device_type, self.gpu_id, self.dla_core, self.allow_gpu_fallback)

    def _to_internal(self) -> trtorch._C.Device:
        internal_dev = trtorch._C.Device()
        internal_dev.device_type = self.device_type
        internal_dev.gpu_id = self.gpu_id
        internal_dev.dla_core = self.dla_core
        internal_dev.allow_gpu_fallback = self.allow_gpu_fallback
        return internal_dev

    @classmethod
    def _from_torch_device(cls, torch_dev: torch.device):
        if torch_dev.type != 'cuda':
            raise ValueError("Torch Device specs must have type \"cuda\"")
        gpu_id = torch_dev.index
        if gpu_id is None:
            raise ValueError("Torch Device specs must include a device index (e.g. torch.device(\"cuda:0\"))")
        return cls(gpu_id=gpu_id)

    @staticmethod
    def _parse_device_str(s):
        s = s.lower()
        spec = s.split(':')
        if spec[0] == "gpu" or spec[0] == "cuda":
            return (_types.DeviceType.GPU, int(spec[1]))
        elif spec[0] == "dla":
            return (_types.DeviceType.DLA, int(spec[1]))
        else:
            raise ValueError("Unexpected device scheme \"{}\", device specs must start with \"gpu\", \"cuda\" or \"dla\"".format(spec[0]))
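

# Minimal usage sketch, not part of the class above: it exercises the
# construction patterns documented in Device.__init__. It assumes the
# compiled trtorch._C extension is importable (i.e. a working trtorch
# install); the printed strings come from Device.__str__.
if __name__ == "__main__":
    print(Device("gpu:0"))                              # string spec, GPU 0
    print(Device("cuda:1"))                             # torch-style string spec
    print(Device(gpu_id=1))                             # keyword spec, GPU 1
    print(Device(dla_core=0, allow_gpu_fallback=True))  # DLA core 0; gpu_id defaults to 0
    print(Device._from_torch_device(torch.device("cuda:0")))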