diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 267b398c6b..23f5fd57c6 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -7,6 +7,7 @@ import torch import torch.fx import torch_tensorrt.ts +from torch._export import ExportedProgram from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo.compile import compile as dynamo_compile @@ -43,6 +44,7 @@ class _IRType(Enum): fx = 1 dynamo = 2 torch_compile = 3 + exported_program = 4 class _ModuleType(Enum): @@ -51,6 +53,7 @@ class _ModuleType(Enum): nn = 0 ts = 1 fx = 2 + ep = 3 def _parse_module_type(module: Any) -> _ModuleType: @@ -61,6 +64,8 @@ def _parse_module_type(module: Any) -> _ModuleType: return _ModuleType.ts elif isinstance(module, torch.fx.GraphModule): return _ModuleType.fx + elif isinstance(module, ExportedProgram): + return _ModuleType.ep elif isinstance(module, torch.nn.Module): return _ModuleType.nn else: @@ -70,6 +75,7 @@ def _parse_module_type(module: Any) -> _ModuleType: def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts]) module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx]) + module_is_exportable = module_type == _ModuleType.ep ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"]) ir_targets_fx = ir == "fx" @@ -95,8 +101,16 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript" ) return _IRType.ts + elif module_is_exportable: + raise ValueError( + "Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input." + ) else: raise ValueError("Module was provided in an unsupported format") + elif ir == "exported_program": + raise ValueError( + "ir=exported_program is not currently supported. Supported ir options : ts|fx|dynamo" + ) else: raise ValueError("Unknown ir was requested")