
✨[Feature] Add lowering pass cases to avoid aten::Int.Tensor calls #1880

Closed
@gs-olive

Problem Context

The function schema aten::Int.Tensor(Tensor a) -> int is a known problematic case for Torch-TensorRT (see #513). The problem arises because once an integer has been converted into a TensorRT tensor, its value can no longer be extracted. An example graph exhibiting this issue is shown below:

  %37436 : int = aten::size(%t.1, %25)  
  %element.2 : Tensor = prim::NumToTensor(%37436)
  %result.2 : Tensor = aten::mul(%element.2, %20)
  %37445 : Tensor = aten::mul(%result.2, %element368.1)
  %37446 : Tensor = aten::mul(%37445, %element369.1)
  %37447 : int = aten::Int(%37446)

In the graph above, none of the intermediate aten::mul operations need to operate on Tensor inputs, since each one simply multiplies single-element Tensors. We already have lowering passes that replace generic cases of this sort; however, catching more of these scenarios would help performance. See, for example:

  void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
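The observation that every intermediate aten::mul in the example works on single-element Tensors can be checked mechanically. The following is a minimal sketch over a hypothetical, simplified Node struct (not the torch::jit::Node API, and not Torch-TensorRT code). It assumes the caller seeds the set of values known to hold one element, such as %20, %element368.1 and %element369.1, which are defined outside the snippet; prim::NumToTensor outputs are added automatically, and the product of two single-element Tensors is treated as single-element so the chain of multiplies is followed.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a TorchScript graph node.
// This is NOT the torch::jit::Node API; it only records what the sketch needs.
struct Node {
  std::string output;               // SSA value defined by the node, e.g. "%result.2"
  std::string kind;                 // operator name, e.g. "aten::mul"
  std::vector<std::string> inputs;  // SSA values consumed by the node
};

// Hypothetical helper: walks the nodes in order and returns the outputs of every
// aten::mul whose inputs are all known to be single-element Tensors.
// `single_element` is seeded by the caller with values assumed to hold exactly
// one element and grows as the walk proceeds.
std::vector<std::string> FindSingleElementMuls(const std::vector<Node>& graph,
                                               std::set<std::string> single_element) {
  std::vector<std::string> candidates;
  for (const Node& n : graph) {
    if (n.kind == "prim::NumToTensor") {
      // NumToTensor wraps one number, so its output has exactly one element.
      assert(n.inputs.size() == 1);
      single_element.insert(n.output);
    } else if (n.kind == "aten::mul") {
      bool all_single = !n.inputs.empty();
      for (const std::string& in : n.inputs) {
        if (single_element.count(in) == 0) all_single = false;
      }
      if (all_single) {
        // A product of single-element Tensors also has a single element,
        // so it can feed later candidates in the chain.
        single_element.insert(n.output);
        candidates.push_back(n.output);
      }
    }
  }
  return candidates;
}
```

Encoding the six-node graph above and seeding %20, %element368.1 and %element369.1 flags %result.2, %37445 and %37446, i.e. all three multiplies. If one of the seeded values were not known to be single-element, the walk would stop flagging at the first multiply that consumes it.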

Proposed Solution

The proposed solution is to add a new lowering pass that resolves cases like the one above by detecting operators such as aten::mul or aten::floor_divide that operate on single-element Tensors. More specifically, if each input to aten::mul is one of the following:

  • prim::NumToTensor output
  • prim::Constant constructing a single-element Tensor
  • A Scalar value (e.g. an int)

Then that aten::mul can be replaced by an integer aten::mul (aten::mul.int(int a, int b) -> int), which takes the original integer arguments as input and outputs an integer. For example:

  %20 : Tensor = prim::Constant[value={1}]()
  %37436 : int = aten::size(%t.1, %25)  
  %element.2 : Tensor = prim::NumToTensor(%37436)
  %result.2 : Tensor = aten::mul(%element.2, %20)
  %37445 : Tensor = aten::mul(%result.2, %element368.1)
  %37446 : Tensor = aten::mul(%37445, %element369.1)
  %37447 : int = aten::Int(%37446)

##### REPLACED WITH #####

  %20 : Tensor = prim::Constant[value={1}]()
  %37436 : int = aten::size(%t.1, %25)  
  %result.2 : int = aten::mul(%37436, 1)
   ...
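The rewrite itself can be sketched on a similar hypothetical, simplified Node struct (again not the torch::jit API, and not Torch-TensorRT's implementation). Following the three input kinds listed above, an input qualifies if it is a prim::NumToTensor output (replaced by the int it wrapped), a single-element prim::Constant Tensor (replaced by its literal, matching the 1 shorthand in the example), or a value that is already an int (standing in for the Scalar case). A Tensor aten::mul whose inputs all qualify is retyped as an int multiply, so later multiplies in the chain qualify too. As an assumption about the end goal named in the title, the sketch also folds an aten::Int whose input has become an int and rewires its users.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a TorchScript node (NOT torch::jit::Node).
struct Node {
  std::string output;               // SSA value defined by the node
  std::string type;                 // "int" or "Tensor"
  std::string kind;                 // e.g. "aten::mul", "prim::NumToTensor"
  std::vector<std::string> inputs;  // SSA values or int literals consumed
  std::string constant;             // literal held by a single-element prim::Constant, else ""
};

using Defs = std::map<std::string, Node*>;

// Hypothetical helper: returns an int-typed stand-in for `value` if it is one of
// the qualifying input kinds, or "" if it does not qualify.
std::string IntEquivalent(const std::string& value, const Defs& defs) {
  auto it = defs.find(value);
  if (it == defs.end()) return "";        // e.g. a graph input: leave it alone
  const Node* def = it->second;
  if (def->type == "int") return value;   // already a scalar
  if (def->kind == "prim::NumToTensor") { // use the int it wrapped
    assert(def->inputs.size() == 1);
    return def->inputs[0];
  }
  if (def->kind == "prim::Constant" && !def->constant.empty()) {
    return def->constant;                 // single-element constant as an int literal
  }
  return "";
}

// Hypothetical pass: rewrites Tensor aten::mul nodes with qualifying inputs into
// int multiplies, then folds aten::Int nodes whose input has become an int.
void ReplaceSingleElementMuls(std::vector<Node>& graph) {
  Defs defs;
  for (Node& n : graph) defs[n.output] = &n;
  std::map<std::string, std::string> forwarded;  // folded aten::Int output -> its int input

  for (Node& n : graph) {
    for (std::string& in : n.inputs) {  // rewire uses of folded aten::Int outputs
      auto f = forwarded.find(in);
      if (f != forwarded.end()) in = f->second;
    }
    if (n.kind == "aten::mul" && n.type == "Tensor") {
      std::vector<std::string> ints;
      for (const std::string& in : n.inputs) ints.push_back(IntEquivalent(in, defs));
      bool all_ok = !ints.empty() &&
                    std::none_of(ints.begin(), ints.end(),
                                 [](const std::string& s) { return s.empty(); });
      if (all_ok) {
        n.inputs = ints;  // now an int multiply, cf. aten::mul.int(int a, int b) -> int
        n.type = "int";
      }
    } else if (n.kind == "aten::Int" && n.inputs.size() == 1) {
      auto src = defs.find(n.inputs[0]);
      if (src != defs.end() && src->second->type == "int") {
        forwarded[n.output] = n.inputs[0];  // aten::Int on an int is redundant
      }
    }
  }
  // Erase folded aten::Int nodes; pointers in `defs` are not used after this.
  graph.erase(std::remove_if(graph.begin(), graph.end(),
                             [&](const Node& n) { return forwarded.count(n.output) > 0; }),
              graph.end());
}
```

Applied to the example graph, with %element368.1 and %element369.1 assumed to be prim::NumToTensor outputs of hypothetical ints %n368 and %n369, and a hypothetical consumer %out of %37447, the three multiplies become int multiplies (the first reading aten::mul(%37436, 1), as in the replacement above), aten::Int(%37446) disappears, and %out reads %37446 directly. The now-unused prim::NumToTensor nodes are left in place for a dead-code elimination step to remove.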

Additional Context

Relates to #1836; this is the first step toward a solution.
