## Description

### Problem Context
The function schema `aten::Int.Tensor(Tensor a) -> int` is a known problematic case for Torch-TensorRT (see #513). The issue arises because once an integer becomes a TensorRT tensor, we can no longer extract its contained data. An example graph exhibiting this issue is shown below:

```
%37436 : int = aten::size(%t.1, %25)
%element.2 : Tensor = prim::NumToTensor(%37436)
%result.2 : Tensor = aten::mul(%element.2, %20)
%37445 : Tensor = aten::mul(%result.2, %element368.1)
%37446 : Tensor = aten::mul(%37445, %element369.1)
%37447 : int = aten::Int(%37446)
```
In the above, none of the intermediate `aten::mul` operations needs to operate on Tensor inputs, since each is simply multiplying single-element Tensors. We already have lowering passes that replace generic cases of this sort, but catching more of these scenarios would help performance. See:
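To make the observation concrete, here is a minimal sketch (not Torch-TensorRT code) that models single-element Tensors as one-element lists, showing that the chain above is integer arithmetic in disguise. All helper names are hypothetical stand-ins for the graph operators:

```python
def num_to_tensor(n):        # stands in for prim::NumToTensor
    return [n]

def tensor_mul(a, b):        # stands in for aten::mul on single-element Tensors
    return [a[0] * b[0]]

def tensor_to_int(t):        # stands in for aten::Int
    return int(t[0])

dim_size = 4                 # stand-in for the aten::size(%t.1, %25) result
element = num_to_tensor(dim_size)
result = tensor_mul(element, num_to_tensor(2))
result = tensor_mul(result, num_to_tensor(3))

# The whole Tensor chain collapses to plain integer multiplication:
assert tensor_to_int(result) == dim_size * 2 * 3  # 24
```

Nothing in the chain ever needs Tensor semantics, which is what makes a scalar-level rewrite safe.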
### Proposed Solution
The proposed solution is to add a new lowering pass that resolves cases like the above by detecting operators such as `aten::mul` or `aten::floor_divide` which operate on single-element Tensors. More specifically, if both inputs to `aten::mul` are any of the following:

- A `prim::NumToTensor` output
- A `prim::Constant` constructing a single-element Tensor
- A `ScalarType`

Then that `aten::mul` can be replaced by a new `aten::mul` which takes the original integer arguments as input and outputs an integer. For example:
```
%20 : Tensor = prim::Constant[value={1}]()
%37436 : int = aten::size(%t.1, %25)
%element.2 : Tensor = prim::NumToTensor(%37436)
%result.2 : Tensor = aten::mul(%element.2, %20)
%37445 : Tensor = aten::mul(%result.2, %element368.1)
%37446 : Tensor = aten::mul(%37445, %element369.1)
%37447 : int = aten::Int(%37446)

##### REPLACED WITH #####

%20 : Tensor = prim::Constant[value={1}]()
%37436 : int = aten::size(%t.1, %25)
%result.2 : int = aten::mul(%37436, 1)
...
```
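A sketch of how such a lowering pass might detect and fold eligible nodes, using a hypothetical `Node` structure in place of `torch::jit::Node`. The names here (`Node`, `is_single_element`, `unwrap_scalar`, `lower_mul`) are illustrative assumptions, not the actual Torch-TensorRT API:

```python
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class Node:
    kind: str                     # e.g. "aten::mul", "prim::NumToTensor"
    inputs: List[Any] = field(default_factory=list)
    value: Any = None             # payload for prim::Constant

def is_single_element(inp: Any) -> bool:
    """True if the input is a scalar or a known single-element-Tensor producer."""
    if isinstance(inp, (int, float)):              # a bare ScalarType
        return True
    if isinstance(inp, Node) and inp.kind == "prim::NumToTensor":
        return True
    return (isinstance(inp, Node) and inp.kind == "prim::Constant"
            and isinstance(inp.value, (int, float)))

def unwrap_scalar(inp: Any) -> Any:
    """Recover the scalar behind a single-element-Tensor producer."""
    if isinstance(inp, Node):
        return inp.inputs[0] if inp.kind == "prim::NumToTensor" else inp.value
    return inp

def lower_mul(node: Node) -> Node:
    """Replace an eligible Tensor aten::mul with an integer aten::mul."""
    assert node.kind == "aten::mul" and all(map(is_single_element, node.inputs))
    a, b = (unwrap_scalar(i) for i in node.inputs)
    return Node("aten::mul.int", inputs=[a, b], value=a * b)

# Mirrors the rewrite above: aten::mul(%element.2, %20) -> aten::mul(%37436, 1)
size_node = Node("prim::NumToTensor", inputs=[4])    # wraps an aten::size output
one = Node("prim::Constant", value=1)
lowered = lower_mul(Node("aten::mul", inputs=[size_node, one]))
assert lowered.value == 4
```

The real pass would perform the same check and replacement on the TorchScript graph in place, then let dead-code elimination clean up the orphaned `prim::NumToTensor` nodes.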
### Additional Context
Relates to #1836; this is the first step in developing a solution.