| 
 | 1 | +from typing import List, Dict, Any  | 
 | 2 | +import torch  | 
 | 3 | +import tensorrt as trt  | 
 | 4 | +import trtorch._C  | 
 | 5 | +from trtorch import types  | 
 | 6 | +from .version import __version__  | 
 | 7 | + | 
 | 8 | +def _supported_input_size_type(input_size: Any) -> bool:  | 
 | 9 | +    if isinstance(input_size, torch.Size):  | 
 | 10 | +        return True  | 
 | 11 | +    elif isinstance(input_size, tuple):  | 
 | 12 | +        return True  | 
 | 13 | +    elif isinstance(input_size, list):  | 
 | 14 | +        return True  | 
 | 15 | +    else:  | 
 | 16 | +        raise TypeError("Input sizes for inputs are required to be a List, tuple or torch.Size or a Dict of three sizes (min, opt, max), found type: " + str(type(input_size)))  | 
 | 17 | + | 
 | 18 | +def _parse_input_sizes(input_sizes: List) -> List:  | 
 | 19 | + | 
 | 20 | +    if any (not isinstance(i, dict) and not _supported_input_size_type(i) for i in input_sizes):  | 
 | 21 | +        raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")  | 
 | 22 | + | 
 | 23 | +    parsed_input_sizes = []  | 
 | 24 | +    for i in input_sizes:  | 
 | 25 | +        if isinstance(i, dict):  | 
 | 26 | +            if all (k in i for k in ["min", "opt", "min"]):  | 
 | 27 | +                in_range = trtorch._C.InputRange()  | 
 | 28 | +                in_range.min = i["min"]  | 
 | 29 | +                in_range.opt = i["opt"]  | 
 | 30 | +                in_range.max = i["max"]  | 
 | 31 | + | 
 | 32 | +                parsed_input_sizes.append(in_range.to_internal_input_range())  | 
 | 33 | + | 
 | 34 | +            elif "opt" in i:  | 
 | 35 | +                in_range = trtorch._C.InputRange()  | 
 | 36 | +                in_range.min = i["opt"]  | 
 | 37 | +                in_range.opt = i["opt"]  | 
 | 38 | +                in_range.max = i["opt"]  | 
 | 39 | + | 
 | 40 | +                parsed_input_sizes.append(in_range.to_internal_input_range())  | 
 | 41 | + | 
 | 42 | +            else:  | 
 | 43 | +                raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")  | 
 | 44 | + | 
 | 45 | +        elif isinstance(i, list):  | 
 | 46 | +            in_range = trtorch._C.InputRange()  | 
 | 47 | +            in_range.min = i  | 
 | 48 | +            in_range.opt = i  | 
 | 49 | +            in_range.max = i  | 
 | 50 | + | 
 | 51 | +            parsed_input_sizes.append(in_range.to_internal_input_range())  | 
 | 52 | + | 
 | 53 | +    return parsed_input_sizes  | 
 | 54 | + | 
 | 55 | +def _parse_op_precision(precision: Any) -> types.dtype:  | 
 | 56 | +    if isinstance(precision, torch.dtype):  | 
 | 57 | +        if precision == torch.int8:  | 
 | 58 | +            return types.dtype.int8  | 
 | 59 | +        elif precision == torch.half:  | 
 | 60 | +            return types.dtype.half  | 
 | 61 | +        elif precision == torch.float:  | 
 | 62 | +            return types.dtype.float  | 
 | 63 | +        else:  | 
 | 64 | +            raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " + str(precision))  | 
 | 65 | + | 
 | 66 | +    elif isinstance(precision, types.DataTypes):  | 
 | 67 | +        return precision  | 
 | 68 | + | 
 | 69 | +    else:  | 
 | 70 | +        raise TypeError("Op precision type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(precision)))  | 
 | 71 | + | 
 | 72 | +def _parse_device_type(device: Any) -> types.DeviceType:  | 
 | 73 | +    if isinstance(device, torch.device):  | 
 | 74 | +        if torch.device.type == 'cuda':  | 
 | 75 | +            return types.DeviceType.gpu  | 
 | 76 | +        else:  | 
 | 77 | +            raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))  | 
 | 78 | + | 
 | 79 | +    elif isinstance(device, types.DeviceType):  | 
 | 80 | +        return device  | 
 | 81 | + | 
 | 82 | +    else:  | 
 | 83 | +        raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))  | 
 | 84 | + | 
 | 85 | +def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:  | 
 | 86 | +    info = trtorch._C._ExtraInfo()  | 
 | 87 | +    if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list):  | 
 | 88 | +        raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")  | 
 | 89 | + | 
 | 90 | +    info.input_ranges = _parse_input_sizes(extra_info["input_shapes"])  | 
 | 91 | + | 
 | 92 | +    if "op_precision" in extra_info:  | 
 | 93 | +        info.op_precision = _parse_op_precision(extra_info["op_precision"])  | 
 | 94 | + | 
 | 95 | +    if "refit" in extra_info:  | 
 | 96 | +        assert isinstance(extra_info["refit"], bool)  | 
 | 97 | +        info.refit = extra_info["refit"]  | 
 | 98 | + | 
 | 99 | +    if "debug" in extra_info:  | 
 | 100 | +        assert isinstance(extra_info["debug"], bool)  | 
 | 101 | +        info.debug = extra_info["debug"]  | 
 | 102 | + | 
 | 103 | +    if "strict_types" in extra_info:  | 
 | 104 | +        assert isinstance(extra_info["strict_types"], bool)  | 
 | 105 | +        info.strict_types = extra_info["strict_types"]  | 
 | 106 | + | 
 | 107 | +    if "allow_gpu_fallback" in extra_info:  | 
 | 108 | +        assert isinstance(extra_info["allow_gpu_fallback"], bool)  | 
 | 109 | +        info.allow_gpu_fallback = extra_info["allow_gpu_fallback"]  | 
 | 110 | + | 
 | 111 | +    if "device" in extra_info:  | 
 | 112 | +        info.device = _parse_device_type(extra_info["device"])  | 
 | 113 | + | 
 | 114 | +    if "capability" in extra_info:  | 
 | 115 | +        assert isinstance(extra_info["capability"], type.EngineCapability)  | 
 | 116 | +        info.capability = extra_info["capability"]  | 
 | 117 | + | 
 | 118 | + | 
 | 119 | +    if "num_min_timing_iters" in extra_info:  | 
 | 120 | +        assert type(extra_info["num_min_timing_iters"]) is int  | 
 | 121 | +        info.num_min_timing_iters = extra_info["num_min_timing_iters"]  | 
 | 122 | + | 
 | 123 | +    if "num_avg_timing_iters" in extra_info:  | 
 | 124 | +        assert type(extra_info["num_avg_timing_iters"]) is int  | 
 | 125 | +        info.num_avg_timing_iters = extra_info["num_avg_timing_iters"]  | 
 | 126 | + | 
 | 127 | +    if "workspace_size" in extra_info:  | 
 | 128 | +        assert type(extra_info["workspace_size"]) is int  | 
 | 129 | +        info.workspace_size = extra_info["workspace_size"]  | 
 | 130 | + | 
 | 131 | +    if "max_batch_size" in extra_info:  | 
 | 132 | +        assert type(extra_info["max_batch_size"]) is int  | 
 | 133 | +        info.max_batch_size = extra_info["max_batch_size"]  | 
 | 134 | + | 
 | 135 | +    return info  | 
 | 136 | + | 
 | 137 | +def compile_module(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:  | 
 | 138 | +    return module  | 
 | 139 | + | 
 | 140 | +def convert_graph_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:  | 
 | 141 | +    return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))  | 
 | 142 | + | 
 | 143 | +def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:  | 
 | 144 | +    return trtorch._C._check_method_op_support(module._c, method_name)  | 
 | 145 | + | 
 | 146 | +def dump_build_info():  | 
 | 147 | +    print(get_build_info())  | 
 | 148 | + | 
 | 149 | +def get_build_info() -> str:  | 
 | 150 | +    build_info = trtorch._C._get_build_info()  | 
 | 151 | +    build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info  | 
 | 152 | +    return build_info  | 
 | 153 | + | 
0 commit comments