@@ -1739,9 +1739,177 @@ def aten_ops_logical_xor(
1739
1739
)
1740
1740
1741
1741
1742
+ def bitwise_type_validator (node : Node ) -> bool :
1743
+ supported_type = [torch .bool , bool ]
1744
+
1745
+ tensor_targets = [
1746
+ torch .ops .aten .bitwise_and .Tensor ,
1747
+ torch .ops .aten .bitwise_or .Tensor ,
1748
+ torch .ops .aten .bitwise_xor .Tensor ,
1749
+ ]
1750
+ scalar_targets = [
1751
+ torch .ops .aten .bitwise_and .Scalar ,
1752
+ torch .ops .aten .bitwise_or .Scalar ,
1753
+ torch .ops .aten .bitwise_xor .Scalar ,
1754
+ ]
1755
+ scalar_tensor_targets = [
1756
+ torch .ops .aten .bitwise_and .Scalar_Tensor ,
1757
+ torch .ops .aten .bitwise_or .Scalar_Tensor ,
1758
+ torch .ops .aten .bitwise_xor .Scalar_Tensor ,
1759
+ ]
1760
+
1761
+ if node .target in tensor_targets :
1762
+ lhs_val = node .args [0 ]
1763
+ rhs_val = node .args [1 ]
1764
+ lhs_meta = lhs_val .meta .get ("tensor_meta" )
1765
+ rhs_meta = rhs_val .meta .get ("tensor_meta" )
1766
+ if lhs_meta is None or rhs_meta is None :
1767
+ return False
1768
+ return lhs_meta .dtype in supported_type and rhs_meta .dtype in supported_type
1769
+
1770
+ elif node .target in scalar_targets :
1771
+ lhs_val = node .args [0 ]
1772
+ rhs_val = node .args [1 ]
1773
+ lhs_meta = lhs_val .meta .get ("tensor_meta" )
1774
+ if lhs_meta is None :
1775
+ return False
1776
+ return lhs_meta .dtype in supported_type and isinstance (rhs_val , bool )
1777
+
1778
+ elif node .target in scalar_tensor_targets :
1779
+ lhs_val = node .args [0 ]
1780
+ rhs_val = node .args [1 ]
1781
+ rhs_meta = rhs_val .meta .get ("tensor_meta" )
1782
+ if rhs_meta is None :
1783
+ return False
1784
+ return isinstance (lhs_val , bool ) and rhs_meta .dtype in supported_type
1785
+
1786
+ else :
1787
+ return False
1788
+
1789
+
1790
+ @dynamo_tensorrt_converter (
1791
+ torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator
1792
+ )
1793
+ @dynamo_tensorrt_converter (
1794
+ torch .ops .aten .bitwise_and .Scalar , capability_validator = bitwise_type_validator
1795
+ )
1796
+ @dynamo_tensorrt_converter (
1797
+ torch .ops .aten .bitwise_and .Scalar_Tensor ,
1798
+ capability_validator = bitwise_type_validator ,
1799
+ )
1800
+ def aten_ops_bitwise_and (
1801
+ ctx : ConversionContext ,
1802
+ target : Target ,
1803
+ args : Tuple [Argument , ...],
1804
+ kwargs : Dict [str , Argument ],
1805
+ name : str ,
1806
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1807
+ return impl .elementwise .bitwise_and (
1808
+ ctx ,
1809
+ target ,
1810
+ SourceIR .ATEN ,
1811
+ name ,
1812
+ args [0 ],
1813
+ args [1 ],
1814
+ )
1815
+
1816
+
1817
+ @dynamo_tensorrt_converter (
1818
+ torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator
1819
+ )
1820
+ @dynamo_tensorrt_converter (
1821
+ torch .ops .aten .bitwise_or .Scalar , capability_validator = bitwise_type_validator
1822
+ )
1823
+ @dynamo_tensorrt_converter (
1824
+ torch .ops .aten .bitwise_or .Scalar_Tensor , capability_validator = bitwise_type_validator
1825
+ )
1826
+ def aten_ops_bitwise_or (
1827
+ ctx : ConversionContext ,
1828
+ target : Target ,
1829
+ args : Tuple [Argument , ...],
1830
+ kwargs : Dict [str , Argument ],
1831
+ name : str ,
1832
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1833
+ return impl .elementwise .bitwise_or (
1834
+ ctx ,
1835
+ target ,
1836
+ SourceIR .ATEN ,
1837
+ name ,
1838
+ args [0 ],
1839
+ args [1 ],
1840
+ )
1841
+
1842
+
1843
+ @dynamo_tensorrt_converter (
1844
+ torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator
1845
+ )
1846
+ @dynamo_tensorrt_converter (
1847
+ torch .ops .aten .bitwise_xor .Scalar , capability_validator = bitwise_type_validator
1848
+ )
1849
+ @dynamo_tensorrt_converter (
1850
+ torch .ops .aten .bitwise_xor .Scalar_Tensor ,
1851
+ capability_validator = bitwise_type_validator ,
1852
+ )
1853
+ def aten_ops_bitwise_xor (
1854
+ ctx : ConversionContext ,
1855
+ target : Target ,
1856
+ args : Tuple [Argument , ...],
1857
+ kwargs : Dict [str , Argument ],
1858
+ name : str ,
1859
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1860
+ return impl .elementwise .bitwise_xor (
1861
+ ctx ,
1862
+ target ,
1863
+ SourceIR .ATEN ,
1864
+ name ,
1865
+ args [0 ],
1866
+ args [1 ],
1867
+ )
1868
+
1869
+
1870
+ def bitwise_not_type_validator (node : Node ) -> bool :
1871
+ val = node .args [0 ]
1872
+ val_meta = val .meta .get ("tensor_meta" )
1873
+
1874
+ if val_meta is None :
1875
+ return False
1876
+
1877
+ supported_type = [torch .bool , bool ]
1878
+ return val_meta .dtype in supported_type
1879
+
1880
+
1881
+ @dynamo_tensorrt_converter (
1882
+ torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator
1883
+ )
1884
+ @enforce_tensor_types (
1885
+ {
1886
+ 0 : (TRTTensor ,),
1887
+ }
1888
+ )
1889
+ def aten_ops_bitwise_not (
1890
+ ctx : ConversionContext ,
1891
+ target : Target ,
1892
+ args : Tuple [Argument , ...],
1893
+ kwargs : Dict [str , Argument ],
1894
+ name : str ,
1895
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1896
+ return impl .unary .bitwise_not (
1897
+ ctx ,
1898
+ target ,
1899
+ SourceIR .ATEN ,
1900
+ name ,
1901
+ args [0 ],
1902
+ )
1903
+
1904
+
1742
1905
@dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor )
1743
1906
@dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar )
1744
- def aten_ops_equal (
1907
+ @enforce_tensor_types (
1908
+ {
1909
+ 0 : (TRTTensor ,),
1910
+ }
1911
+ )
1912
+ def aten_ops_eq (
1745
1913
ctx : ConversionContext ,
1746
1914
target : Target ,
1747
1915
args : Tuple [Argument , ...],
@@ -1758,9 +1926,38 @@ def aten_ops_equal(
1758
1926
)
1759
1927
1760
1928
1929
+ @dynamo_tensorrt_converter (torch .ops .aten .ne .Tensor )
1930
+ @dynamo_tensorrt_converter (torch .ops .aten .ne .Scalar )
1931
+ @enforce_tensor_types (
1932
+ {
1933
+ 0 : (TRTTensor ,),
1934
+ }
1935
+ )
1936
+ def aten_ops_ne (
1937
+ ctx : ConversionContext ,
1938
+ target : Target ,
1939
+ args : Tuple [Argument , ...],
1940
+ kwargs : Dict [str , Argument ],
1941
+ name : str ,
1942
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1943
+ return impl .elementwise .ne (
1944
+ ctx ,
1945
+ target ,
1946
+ SourceIR .ATEN ,
1947
+ name ,
1948
+ args [0 ],
1949
+ args [1 ],
1950
+ )
1951
+
1952
+
1761
1953
@dynamo_tensorrt_converter (torch .ops .aten .gt .Tensor )
1762
1954
@dynamo_tensorrt_converter (torch .ops .aten .gt .Scalar )
1763
- def aten_ops_greater (
1955
+ @enforce_tensor_types (
1956
+ {
1957
+ 0 : (TRTTensor ,),
1958
+ }
1959
+ )
1960
+ def aten_ops_gt (
1764
1961
ctx : ConversionContext ,
1765
1962
target : Target ,
1766
1963
args : Tuple [Argument , ...],
@@ -1777,9 +1974,38 @@ def aten_ops_greater(
1777
1974
)
1778
1975
1779
1976
1977
+ @dynamo_tensorrt_converter (torch .ops .aten .ge .Tensor )
1978
+ @dynamo_tensorrt_converter (torch .ops .aten .ge .Scalar )
1979
+ @enforce_tensor_types (
1980
+ {
1981
+ 0 : (TRTTensor ,),
1982
+ }
1983
+ )
1984
+ def aten_ops_ge (
1985
+ ctx : ConversionContext ,
1986
+ target : Target ,
1987
+ args : Tuple [Argument , ...],
1988
+ kwargs : Dict [str , Argument ],
1989
+ name : str ,
1990
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1991
+ return impl .elementwise .ge (
1992
+ ctx ,
1993
+ target ,
1994
+ SourceIR .ATEN ,
1995
+ name ,
1996
+ args [0 ],
1997
+ args [1 ],
1998
+ )
1999
+
2000
+
1780
2001
@dynamo_tensorrt_converter (torch .ops .aten .lt .Tensor )
1781
2002
@dynamo_tensorrt_converter (torch .ops .aten .lt .Scalar )
1782
- def aten_ops_less (
2003
+ @enforce_tensor_types (
2004
+ {
2005
+ 0 : (TRTTensor ,),
2006
+ }
2007
+ )
2008
+ def aten_ops_lt (
1783
2009
ctx : ConversionContext ,
1784
2010
target : Target ,
1785
2011
args : Tuple [Argument , ...],
@@ -1796,6 +2022,30 @@ def aten_ops_less(
1796
2022
)
1797
2023
1798
2024
2025
+ @dynamo_tensorrt_converter (torch .ops .aten .le .Tensor )
2026
+ @dynamo_tensorrt_converter (torch .ops .aten .le .Scalar )
2027
+ @enforce_tensor_types (
2028
+ {
2029
+ 0 : (TRTTensor ,),
2030
+ }
2031
+ )
2032
+ def aten_ops_le (
2033
+ ctx : ConversionContext ,
2034
+ target : Target ,
2035
+ args : Tuple [Argument , ...],
2036
+ kwargs : Dict [str , Argument ],
2037
+ name : str ,
2038
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
2039
+ return impl .elementwise .le (
2040
+ ctx ,
2041
+ target ,
2042
+ SourceIR .ATEN ,
2043
+ name ,
2044
+ args [0 ],
2045
+ args [1 ],
2046
+ )
2047
+
2048
+
1799
2049
def conv_param_validator (conv_node : Node ) -> bool :
1800
2050
return conv_node .args [7 ] in ([0 ], [0 , 0 ], [0 , 0 , 0 ])
1801
2051
0 commit comments