Description
Feature Context
Models that are fully supported in TRT, except that their input type is a collection, should be fully compilable in Torch-TRT. Since Torch-executed list packing and unpacking code is already inserted (by necessity) even when a model is fully supported, there should be no need to disable full compilation when complex input types are provided. Additionally, operators such as prim::ListUnpack should not be added to torch_executed_ops automatically when input_signature is used, as they currently are, because evaluators for them already exist.
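To illustrate why the packing/unpacking step is independent of whether the model body runs in TensorRT, here is a minimal pure-Python sketch. The idea is that an engine consumes a flat list of tensors, so the surrounding Torch code must unpack a nested collection before the engine and can repack it afterwards. The helpers flatten and unflatten and the structure "spec" format are hypothetical and are not Torch-TensorRT APIs; plain strings stand in for tensors.

```python
from typing import Any, List, Tuple, Union

# Hypothetical structure record: "leaf" for a single input, or (kind, children)
# for a tuple/list. Strings stand in for tensors in this sketch.
Spec = Union[str, Tuple[str, List[Any]]]


def flatten(inputs: Any) -> Tuple[List[Any], Spec]:
    """Unpack nested tuples/lists into a flat list of leaves plus a structure spec."""
    if isinstance(inputs, (tuple, list)):
        leaves: List[Any] = []
        child_specs: List[Spec] = []
        for item in inputs:
            sub_leaves, sub_spec = flatten(item)
            leaves.extend(sub_leaves)
            child_specs.append(sub_spec)
        kind = "tuple" if isinstance(inputs, tuple) else "list"
        return leaves, (kind, child_specs)
    # Anything that is not a tuple/list is treated as a single engine input.
    return [inputs], "leaf"


def unflatten(leaves: List[Any], spec: Spec) -> Any:
    """Repack a flat list of leaves into the structure recorded by flatten()."""
    it = iter(leaves)

    def build(s: Spec) -> Any:
        if s == "leaf":
            return next(it)
        kind, children = s
        items = [build(child) for child in children]
        return tuple(items) if kind == "tuple" else items

    return build(spec)


if __name__ == "__main__":
    nested = (("a", "b"), ["c"])
    flat, spec = flatten(nested)  # the flat list is what an engine would consume
    print(flat)
    print(unflatten(flat, spec) == nested)
```

Because this unpack/repack step only rearranges inputs and never computes on them, it can run the same way whether the body in between is a single TensorRT engine or a partitioned graph.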
Desired Solution
The preferred solution is to lift the require_full_compilation=False requirement when input_signature is used, and to stop forcing collection-based operators to execute in Torch fallback. Both restrictions are currently enforced here:
TensorRT/py/torch_tensorrt/ts/_compile_spec.py, lines 259 to 300 at commit 835abf0
```python
elif compile_spec["input_signature"] is not None:
    log(
        Level.Warning,
        "Input signature parsing is an experimental feature, behavior and APIs may change",
    )
    signature = _parse_input_signature(compile_spec["input_signature"])
    info.input_signature = _C.InputSignature(signature)  # py_object
    if not compile_spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
        )
    log(
        Level.Debug,
        "Grouped inputs currently requires additional settings to enable the feature",
    )
    log(
        Level.Debug,
        """Adding the following ops to torch_executed_ops:
    - aten::__getitem__
    - prim::ListConstruct
    - prim::ListUnpack
    - prim::TupleIndex
    - prim::TupleConstruct
    - prim::TupleUnpack
    """,
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "aten::__getitem__"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::ListConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleUnpack"
    )
```
This would also require modifying the C++ core code, to ensure that relaxing these restrictions does not cause issues in the existing compilation phases.
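For comparison, a minimal sketch of the Python-side difference between the current and proposed behavior, operating on a plain dict shaped like the compile_spec in the excerpt above. The function apply_input_signature_rules and its relaxed flag are hypothetical illustrations, not part of Torch-TensorRT; the op list and error message are copied from the excerpt. This models only the front-end check, not the C++ core changes the real fix would also need.

```python
import copy
from typing import Any, Dict

# Ops the current front end forces into Torch fallback (from the excerpt above).
COLLECTION_OPS = [
    "aten::__getitem__",
    "prim::ListConstruct",
    "prim::ListUnpack",
    "prim::TupleIndex",
    "prim::TupleConstruct",
    "prim::TupleUnpack",
]


def apply_input_signature_rules(
    compile_spec: Dict[str, Any], relaxed: bool = False
) -> Dict[str, Any]:
    """Hypothetical model of input_signature handling.

    relaxed=False mimics the current behavior; relaxed=True mimics the proposal.
    The input dict is never mutated; a deep copy is returned.
    """
    spec = copy.deepcopy(compile_spec)
    if spec.get("input_signature") is None:
        # No grouped inputs: nothing to enforce in either mode.
        return spec
    if relaxed:
        # Proposed: allow full compilation and rely on existing evaluators
        # instead of forcing collection ops into Torch fallback.
        return spec
    # Current: partial compilation is mandatory ...
    if not spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled, "
            "this restriction will be relaxed in a future release"
        )
    # ... and every collection op is forced to run in Torch.
    spec["torch_fallback"]["forced_fallback_ops"].extend(COLLECTION_OPS)
    return spec
```

Under this model, a spec with fallback disabled raises in the current mode but passes through untouched in the relaxed mode, which is the behavior change the issue asks for.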
Additional Context
A proof of concept for this feature already exists in PR #1599; it could serve as a template for enabling full compilation with collection inputs as well. This would complete the Collection IO plan discussed in a comment on #629.