Skip to content

Commit 53401dd

Browse files
authored
feat: support more elementwise and unary dynamo converters (#2429)
1 parent cd158b6 commit 53401dd

File tree

14 files changed

+876
-67
lines changed

14 files changed

+876
-67
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 253 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,9 +1739,177 @@ def aten_ops_logical_xor(
17391739
)
17401740

17411741

1742+
def bitwise_type_validator(node: Node) -> bool:
    """Capability validator for the bitwise and/or/xor converters.

    The TensorRT implementation only handles boolean operands, so the node
    is accepted only when every operand is boolean: ``torch.bool`` for
    tensor operands (read from the node's ``tensor_meta``) and a Python
    ``bool`` for scalar operands. Operands without tensor metadata are
    rejected, as are any overloads not listed below.
    """
    bool_types = [torch.bool, bool]

    tensor_overloads = (
        torch.ops.aten.bitwise_and.Tensor,
        torch.ops.aten.bitwise_or.Tensor,
        torch.ops.aten.bitwise_xor.Tensor,
    )
    scalar_overloads = (
        torch.ops.aten.bitwise_and.Scalar,
        torch.ops.aten.bitwise_or.Scalar,
        torch.ops.aten.bitwise_xor.Scalar,
    )
    scalar_tensor_overloads = (
        torch.ops.aten.bitwise_and.Scalar_Tensor,
        torch.ops.aten.bitwise_or.Scalar_Tensor,
        torch.ops.aten.bitwise_xor.Scalar_Tensor,
    )

    if node.target in tensor_overloads:
        # Tensor (op) Tensor: both sides must carry boolean tensor metadata.
        lhs_meta = node.args[0].meta.get("tensor_meta")
        rhs_meta = node.args[1].meta.get("tensor_meta")
        if lhs_meta is None or rhs_meta is None:
            return False
        return lhs_meta.dtype in bool_types and rhs_meta.dtype in bool_types

    if node.target in scalar_overloads:
        # Tensor (op) Scalar: lhs is a tensor node, rhs a Python scalar.
        lhs_meta = node.args[0].meta.get("tensor_meta")
        if lhs_meta is None:
            return False
        return lhs_meta.dtype in bool_types and isinstance(node.args[1], bool)

    if node.target in scalar_tensor_overloads:
        # Scalar (op) Tensor: lhs is a Python scalar, rhs a tensor node.
        rhs_meta = node.args[1].meta.get("tensor_meta")
        if rhs_meta is None:
            return False
        return isinstance(node.args[0], bool) and rhs_meta.dtype in bool_types

    # Any other target is not handled by these converters.
    return False
1788+
1789+
1790+
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
)
def aten_ops_bitwise_and(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``bitwise_and`` (Tensor/Scalar/Scalar_Tensor overloads).

    Dispatches the two operands to the elementwise bitwise-and
    implementation; the validator restricts use to boolean inputs.
    """
    lhs, rhs = args[0], args[1]
    return impl.elementwise.bitwise_and(ctx, target, SourceIR.ATEN, name, lhs, rhs)
1815+
1816+
1817+
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
)
def aten_ops_bitwise_or(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``bitwise_or`` (Tensor/Scalar/Scalar_Tensor overloads).

    Dispatches the two operands to the elementwise bitwise-or
    implementation; the validator restricts use to boolean inputs.
    """
    lhs, rhs = args[0], args[1]
    return impl.elementwise.bitwise_or(ctx, target, SourceIR.ATEN, name, lhs, rhs)
1841+
1842+
1843+
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
)
def aten_ops_bitwise_xor(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``bitwise_xor`` (Tensor/Scalar/Scalar_Tensor overloads).

    Dispatches the two operands to the elementwise bitwise-xor
    implementation; the validator restricts use to boolean inputs.
    """
    lhs, rhs = args[0], args[1]
    return impl.elementwise.bitwise_xor(ctx, target, SourceIR.ATEN, name, lhs, rhs)
1868+
1869+
1870+
def bitwise_not_type_validator(node: Node) -> bool:
    """Capability validator for the ``bitwise_not`` converter.

    Accept the node only when its single input carries tensor metadata with
    a boolean dtype; inputs without metadata are rejected.
    """
    meta = node.args[0].meta.get("tensor_meta")
    return meta is not None and meta.dtype in [torch.bool, bool]
1879+
1880+
1881+
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_bitwise_not(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``bitwise_not`` via the unary bitwise-not implementation.

    The validator limits this converter to boolean inputs, and the input at
    position 0 is required to already be a TRT tensor.
    """
    input_val = args[0]
    return impl.unary.bitwise_not(ctx, target, SourceIR.ATEN, name, input_val)
1903+
1904+
17421905
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
17431906
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
1744-
def aten_ops_equal(
1907+
@enforce_tensor_types(
1908+
{
1909+
0: (TRTTensor,),
1910+
}
1911+
)
1912+
def aten_ops_eq(
17451913
ctx: ConversionContext,
17461914
target: Target,
17471915
args: Tuple[Argument, ...],
@@ -1758,9 +1926,38 @@ def aten_ops_equal(
17581926
)
17591927

17601928

1929+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_ne(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``ne`` (not-equal) for Tensor and Scalar overloads.

    Forwards both operands to the elementwise ``ne`` implementation; the
    first input must already be a TRT tensor.
    """
    lhs, rhs = args[0], args[1]
    return impl.elementwise.ne(ctx, target, SourceIR.ATEN, name, lhs, rhs)
1951+
1952+
17611953
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
17621954
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
1763-
def aten_ops_greater(
1955+
@enforce_tensor_types(
1956+
{
1957+
0: (TRTTensor,),
1958+
}
1959+
)
1960+
def aten_ops_gt(
17641961
ctx: ConversionContext,
17651962
target: Target,
17661963
args: Tuple[Argument, ...],
@@ -1777,9 +1974,38 @@ def aten_ops_greater(
17771974
)
17781975

17791976

1977+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_ge(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``ge`` (greater-or-equal) for Tensor and Scalar overloads.

    Forwards both operands to the elementwise ``ge`` implementation; the
    first input must already be a TRT tensor.
    """
    lhs, rhs = args[0], args[1]
    return impl.elementwise.ge(ctx, target, SourceIR.ATEN, name, lhs, rhs)
1999+
2000+
17802001
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
17812002
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
1782-
def aten_ops_less(
2003+
@enforce_tensor_types(
2004+
{
2005+
0: (TRTTensor,),
2006+
}
2007+
)
2008+
def aten_ops_lt(
17832009
ctx: ConversionContext,
17842010
target: Target,
17852011
args: Tuple[Argument, ...],
@@ -1796,6 +2022,30 @@ def aten_ops_less(
17962022
)
17972023

17982024

2025+
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_le(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten ``le`` (less-or-equal) for Tensor and Scalar overloads.

    Forwards both operands to the elementwise ``le`` implementation; the
    first input must already be a TRT tensor.
    """
    lhs, rhs = args[0], args[1]
    return impl.elementwise.le(ctx, target, SourceIR.ATEN, name, lhs, rhs)
2047+
2048+
17992049
def conv_param_validator(conv_node: Node) -> bool:
    """Capability validator for convolution conversion.

    Accepts the node only when argument 7 (``output_padding`` in the aten
    convolution schema) is an all-zero list for 1D, 2D, or 3D convolution.
    """
    zero_output_padding = ([0], [0, 0], [0, 0, 0])
    return conv_node.args[7] in zero_output_padding
18012051

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def convert_binary_elementwise(
5858
source_ir: Optional[SourceIR],
5959
name: str,
6060
op_type: trt.ElementWiseOperation,
61-
lhs_val: Union[int, float, TRTTensor, torch.Tensor],
62-
rhs_val: Union[int, float, TRTTensor, torch.Tensor],
61+
lhs_val: Union[int, float, bool, TRTTensor, torch.Tensor],
62+
rhs_val: Union[int, float, bool, TRTTensor, torch.Tensor],
6363
) -> TRTTensor:
6464
"""
6565
This function adds a TensorRT elementwise layer. We allow both operands to be
@@ -120,11 +120,11 @@ def convert_binary_elementwise(
120120
# Note that the dtype here is supposed to be the same as the scalar
121121
# dtype but we don't have a way to detect whether it makes sense for the
122122
# scalar to be float or half. Hence we go with the lhs dtype.
123-
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
123+
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
124124
rhs_val = np.array(
125125
[rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
126126
)
127-
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
127+
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
128128
lhs_val = np.array(
129129
[lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
130130
)

0 commit comments

Comments
 (0)