From c2c2e39d604c1eda63918e8c9a8ca09f48548f42 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 10 Aug 2023 15:36:48 -0700 Subject: [PATCH 1/3] feat: Add ExportedProgram as an IR Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_compile.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index fc1c0d30d8..4599147055 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -15,6 +15,7 @@ from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.ts._compiler import compile as torchscript_compile from typing_extensions import TypeGuard +from torch._export import ExportedProgram def _non_fx_input_interface( @@ -36,6 +37,7 @@ class _IRType(Enum): fx = 1 dynamo = 2 torch_compile = 3 + exported_program = 4 class _ModuleType(Enum): @@ -44,6 +46,7 @@ class _ModuleType(Enum): nn = 0 ts = 1 fx = 2 + ep = 3 def _parse_module_type(module: Any) -> _ModuleType: @@ -54,6 +57,8 @@ def _parse_module_type(module: Any) -> _ModuleType: return _ModuleType.ts elif isinstance(module, torch.fx.GraphModule): return _ModuleType.fx + elif isinstance(module, ExportedProgram): + return _ModuleType.ep elif isinstance(module, torch.nn.Module): return _ModuleType.nn else: @@ -63,11 +68,13 @@ def _parse_module_type(module: Any) -> _ModuleType: def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts]) module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx]) + module_is_exportable = module_type == _ModuleType.ep ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"]) ir_targets_fx = ir == "fx" ir_targets_dynamo = ir == "dynamo" ir_targets_torch_compile = ir == "torch_compile" + ir_targets_ep = ir == "exported_program" if module_is_tsable and ir_targets_torchscript: return _IRType.ts @@ -75,6 +82,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: return _IRType.fx elif module_is_fxable and ir_targets_dynamo: return _IRType.dynamo + elif module_is_fxable and ir_targets_ep: + return _IRType.dynamo elif module_is_fxable and ir_targets_torch_compile: return _IRType.torch_compile else: @@ -91,6 +100,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript", ) return _IRType.ts + elif module_is_exportable: + raise ValueError("Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input.") else: raise ValueError("Module was provided in an unsupported format") else: From f9c836cacdf3be0dfb7ef7a5d2d834d2e240d40a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 10 Aug 2023 17:21:00 -0700 Subject: [PATCH 2/3] chore: linter fixes --- py/torch_tensorrt/_compile.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 4599147055..9295cd6576 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -101,7 +101,9 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ) return _IRType.ts elif module_is_exportable: - raise ValueError("Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input.") + raise ValueError( + "Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input." + ) else: raise ValueError("Module was provided in an unsupported format") else: From ead2aa681625420dd2ba300f785a0c0362e0c629 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 11 Aug 2023 16:47:46 -0700 Subject: [PATCH 3/3] chore: modify behavior of exported_program Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_compile.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 9295cd6576..f64cc82e75 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -6,6 +6,7 @@ import torch import torch.fx import torch_tensorrt.ts +from torch._export import ExportedProgram from torch_tensorrt import logging from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input @@ -15,7 +16,6 @@ from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.ts._compiler import compile as torchscript_compile from typing_extensions import TypeGuard -from torch._export import ExportedProgram def _non_fx_input_interface( @@ -74,7 +74,6 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_fx = ir == "fx" ir_targets_dynamo = ir == "dynamo" ir_targets_torch_compile = ir == "torch_compile" - ir_targets_ep = ir == "exported_program" if module_is_tsable and ir_targets_torchscript: return _IRType.ts @@ -82,8 +81,6 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: return _IRType.fx elif module_is_fxable and ir_targets_dynamo: return _IRType.dynamo - elif module_is_fxable and ir_targets_ep: - return _IRType.dynamo elif module_is_fxable and ir_targets_torch_compile: return _IRType.torch_compile else: @@ -103,9 +100,13 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: elif module_is_exportable: raise ValueError( "Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input." - ) + ) else: raise ValueError("Module was provided in an unsupported format") + elif ir == "exported_program": + raise ValueError( + "ir=exported_program is not currently supported. Supported ir options : ts|fx|dynamo" + ) else: raise ValueError("Unknown ir was requested")