@@ -46,31 +46,22 @@ def get_ir(target: Target) -> SourceIR:
4646 return SourceIR .UNKNOWN
4747
4848
49- @dynamo_tensorrt_converter (torch .ops .aten .native_batch_norm .default ) # type: ignore[misc]
50- def aten_ops_native_batch_norm (
51- ctx : ConversionContext ,
52- target : Target ,
53- args : Tuple [Argument , ...],
54- kwargs : Dict [str , Argument ],
55- name : str ,
56- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
57- return impl .normalization .native_batch_norm (
58- ctx ,
59- target ,
60- SourceIR .ATEN ,
61- name ,
62- input = args [0 ],
63- weight = args [1 ],
64- bias = args [2 ],
65- running_mean = args [3 ],
66- running_var = args [4 ],
67- training = args [5 ],
68- momentum = args [6 ],
69- eps = args [7 ],
49+ def one_user_validator (node : Node ) -> bool :
50+ # Validate only one user, which is a getitem node that accesses the first element in the list
51+ return (
52+ len (node .users ) == 1
53+ and list (node .users )[0 ].target == operator .getitem
54+ and list (node .users )[0 ].args [1 ] == 0
7055 )
7156
7257
73- @dynamo_tensorrt_converter (torch .ops .aten .batch_norm ) # type: ignore[misc]
58+ @dynamo_tensorrt_converter (torch .ops .aten .native_batch_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
59+ @dynamo_tensorrt_converter (torch .ops .aten .batch_norm .default ) # type: ignore[misc]
60+ @enforce_tensor_types (
61+ {
62+ 0 : (TRTTensor ,),
63+ }
64+ ) # type: ignore[misc]
7465def aten_ops_batch_norm (
7566 ctx : ConversionContext ,
7667 target : Target ,
@@ -91,32 +82,18 @@ def aten_ops_batch_norm(
9182 training = args [5 ],
9283 momentum = args [6 ],
9384 eps = args [7 ],
94- cudnn_enabled = args [8 ],
95- )
96-
97-
98- @dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default ) # type: ignore[misc]
99- def aten_ops_native_layer_norm (
100- ctx : ConversionContext ,
101- target : Target ,
102- args : Tuple [Argument , ...],
103- kwargs : Dict [str , Argument ],
104- name : str ,
105- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
106- return impl .normalization .native_layer_norm (
107- ctx ,
108- target ,
109- SourceIR .ATEN ,
110- name ,
111- input = args [0 ],
112- normalized_shape = args [1 ],
113- weight = args [2 ],
114- bias = args [3 ],
115- eps = args [4 ],
85+ cudnn_enabled = args_bounds_check (args , 8 , True ),
86+ return_mean_rstd = (target == torch .ops .aten .native_batch_norm .default ),
11687 )
11788
11889
90+ @dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
11991@dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
92+ @enforce_tensor_types (
93+ {
94+ 0 : (TRTTensor ,),
95+ }
96+ ) # type: ignore[misc]
12097def aten_ops_layer_norm (
12198 ctx : ConversionContext ,
12299 target : Target ,
@@ -135,10 +112,16 @@ def aten_ops_layer_norm(
135112 bias = args_bounds_check (args , 3 ),
136113 eps = args_bounds_check (args , 4 , 1e-05 ),
137114 cudnn_enable = args_bounds_check (args , 5 , True ),
115+ return_mean_rstd = (target == torch .ops .aten .native_layer_norm .default ),
138116 )
139117
140118
141- @dynamo_tensorrt_converter (torch .ops .aten .native_group_norm .default ) # type: ignore[misc]
119+ @dynamo_tensorrt_converter (torch .ops .aten .native_group_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
120+ @enforce_tensor_types (
121+ {
122+ 0 : (TRTTensor ,),
123+ }
124+ ) # type: ignore[misc]
142125def aten_ops_native_group_norm (
143126 ctx : ConversionContext ,
144127 target : Target ,
@@ -163,6 +146,11 @@ def aten_ops_native_group_norm(
163146
164147
165148@dynamo_tensorrt_converter (torch .ops .aten .group_norm .default ) # type: ignore[misc]
149+ @enforce_tensor_types (
150+ {
151+ 0 : (TRTTensor ,),
152+ }
153+ ) # type: ignore[misc]
166154def aten_ops_group_norm (
167155 ctx : ConversionContext ,
168156 target : Target ,
@@ -856,15 +844,6 @@ def aten_ops_prod(
856844 )
857845
858846
859- def one_user_validator (node : Node ) -> bool :
860- # Validate only one user, which is a getitem node that accesses the first element in the list
861- return (
862- len (node .users ) == 1
863- and list (node .users )[0 ].target == operator .getitem
864- and list (node .users )[0 ].args [1 ] == 0
865- )
866-
867-
868847@dynamo_tensorrt_converter (torch .ops .aten .max .default ) # type: ignore[misc]
869848@dynamo_tensorrt_converter (torch .ops .aten .max .dim , capability_validator = one_user_validator ) # type: ignore[misc]
870849def aten_ops_max (
0 commit comments