@@ -1182,11 +1182,14 @@ at::Tensor XLANativeFunctions::avg_pool2d(
11821182 at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
11831183 std::optional<int64_t > divisor_override) {
11841184 TORCH_LAZY_FN_COUNTER_TIMED_TRACING (" xla::" );
1185- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
1186- return bridge::AtenFromXlaTensor (tensor_methods::avg_pool_nd (
1187- xla_self, /* spatial_dim_count=*/ 2 , XlaHelpers::I64List (kernel_size),
1188- XlaHelpers::I64List (stride), XlaHelpers::I64List (padding), ceil_mode,
1189- count_include_pad, divisor_override));
1185+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
1186+ bridge::GetXlaTensor (self));
1187+ XLA_ASSIGN_OR_THROW (
1188+ absl_nonnull XLATensorPtr output,
1189+ tensor_methods::avg_pool_nd (xla_self, /* spatial_dim_count=*/ 2 ,
1190+ kernel_size, stride, padding, ceil_mode,
1191+ count_include_pad, divisor_override));
1192+ return bridge::AtenFromXlaTensor (std::move (output));
11901193}
11911194
11921195at::Tensor XLANativeFunctions::avg_pool2d_backward (
@@ -1203,25 +1206,31 @@ at::Tensor XLANativeFunctions::avg_pool2d_backward(
12031206 count_include_pad,
12041207 divisor_override);
12051208 }
1206- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_grad_output,
1209+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_grad_output,
12071210 bridge::GetXlaTensor (grad_output));
1208- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
1209- return bridge::AtenFromXlaTensor (tensor_methods::avg_pool_nd_backward (
1210- xla_grad_output, xla_self, /* spatial_dim_count=*/ 2 ,
1211- XlaHelpers::I64List (kernel_size), XlaHelpers::I64List (stride),
1212- XlaHelpers::I64List (padding), ceil_mode, count_include_pad));
1211+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
1212+ bridge::GetXlaTensor (self));
1213+ XLA_ASSIGN_OR_THROW (
1214+ absl_nonnull XLATensorPtr output,
1215+ tensor_methods::avg_pool_nd_backward (
1216+ xla_grad_output, xla_self, /* spatial_dim_count=*/ 2 , kernel_size,
1217+ stride, padding, ceil_mode, count_include_pad));
1218+ return bridge::AtenFromXlaTensor (std::move (output));
12131219}
12141220
12151221at::Tensor XLANativeFunctions::avg_pool3d (
12161222 const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride,
12171223 at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
12181224 std::optional<int64_t > divisor_override) {
12191225 TORCH_LAZY_FN_COUNTER_TIMED_TRACING (" xla::" );
1220- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
1221- return bridge::AtenFromXlaTensor (tensor_methods::avg_pool_nd (
1222- xla_self, /* spatial_dim_count=*/ 3 , XlaHelpers::I64List (kernel_size),
1223- XlaHelpers::I64List (stride), XlaHelpers::I64List (padding), ceil_mode,
1224- count_include_pad, divisor_override));
1226+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
1227+ bridge::GetXlaTensor (self));
1228+ XLA_ASSIGN_OR_THROW (
1229+ absl_nonnull XLATensorPtr output,
1230+ tensor_methods::avg_pool_nd (xla_self, /* spatial_dim_count=*/ 3 ,
1231+ kernel_size, stride, padding, ceil_mode,
1232+ count_include_pad, divisor_override));
1233+ return bridge::AtenFromXlaTensor (std::move (output));
12251234}
12261235
12271236at::Tensor XLANativeFunctions::avg_pool3d_backward (
@@ -1238,13 +1247,16 @@ at::Tensor XLANativeFunctions::avg_pool3d_backward(
12381247 count_include_pad,
12391248 divisor_override);
12401249 }
1241- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_grad_output,
1250+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_grad_output,
12421251 bridge::GetXlaTensor (grad_output));
1243- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
1244- return bridge::AtenFromXlaTensor (tensor_methods::avg_pool_nd_backward (
1245- xla_grad_output, xla_self, /* spatial_dim_count=*/ 3 ,
1246- XlaHelpers::I64List (kernel_size), XlaHelpers::I64List (stride),
1247- XlaHelpers::I64List (padding), ceil_mode, count_include_pad));
1252+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
1253+ bridge::GetXlaTensor (self));
1254+ XLA_ASSIGN_OR_THROW (
1255+ absl_nonnull XLATensorPtr output,
1256+ tensor_methods::avg_pool_nd_backward (
1257+ xla_grad_output, xla_self, /* spatial_dim_count=*/ 3 , kernel_size,
1258+ stride, padding, ceil_mode, count_include_pad));
1259+ return bridge::AtenFromXlaTensor (std::move (output));
12481260}
12491261
12501262at::Tensor XLANativeFunctions::baddbmm (const at::Tensor& self,
@@ -2327,12 +2339,14 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool2d_with_indices(
23272339 dilation,
23282340 ceil_mode);
23292341 }
2330- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
2331- auto outputs = tensor_methods::max_pool_nd (
2332- xla_self, /* spatial_dim_count=*/ 2 , XlaHelpers::I64List (kernel_size),
2333- XlaHelpers::I64List (stride), XlaHelpers::I64List (padding), ceil_mode);
2334- return std::make_tuple (bridge::AtenFromXlaTensor (std::get<0 >(outputs)),
2335- bridge::AtenFromXlaTensor (std::get<1 >(outputs)));
2342+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
2343+ bridge::GetXlaTensor (self));
2344+ std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
2345+ XLA_ASSIGN_OR_THROW (output, tensor_methods::max_pool_nd (
2346+ xla_self, /* spatial_dim_count=*/ 2 ,
2347+ kernel_size, stride, padding, ceil_mode));
2348+ return std::make_tuple (bridge::AtenFromXlaTensor (std::get<0 >(output)),
2349+ bridge::AtenFromXlaTensor (std::get<1 >(output)));
23362350}
23372351
23382352at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward (
@@ -2350,13 +2364,15 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward(
23502364 padding, dilation,
23512365 ceil_mode, indices);
23522366 }
2353- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_grad_output,
2367+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_grad_output,
23542368 bridge::GetXlaTensor (grad_output));
2355- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
2356- return bridge::AtenFromXlaTensor (tensor_methods::max_pool_nd_backward (
2357- xla_grad_output, xla_self, /* spatial_dim_count=*/ 2 ,
2358- XlaHelpers::I64List (kernel_size), XlaHelpers::I64List (stride),
2359- XlaHelpers::I64List (padding), ceil_mode));
2369+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
2370+ bridge::GetXlaTensor (self));
2371+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr output,
2372+ tensor_methods::max_pool_nd_backward (
2373+ xla_grad_output, xla_self, /* spatial_dim_count=*/ 2 ,
2374+ kernel_size, stride, padding, ceil_mode));
2375+ return bridge::AtenFromXlaTensor (std::move (output));
23602376}
23612377
23622378at::Tensor XLANativeFunctions::max_pool3d (
@@ -2382,13 +2398,15 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward(
23822398 padding, dilation,
23832399 ceil_mode, indices);
23842400 }
2385- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_grad_output,
2401+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_grad_output,
23862402 bridge::GetXlaTensor (grad_output));
2387- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
2388- return bridge::AtenFromXlaTensor (tensor_methods::max_pool_nd_backward (
2389- xla_grad_output, xla_self, /* spatial_dim_count=*/ 3 ,
2390- XlaHelpers::I64List (kernel_size), XlaHelpers::I64List (stride),
2391- XlaHelpers::I64List (padding), ceil_mode));
2403+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
2404+ bridge::GetXlaTensor (self));
2405+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr output,
2406+ tensor_methods::max_pool_nd_backward (
2407+ xla_grad_output, xla_self, /* spatial_dim_count=*/ 3 ,
2408+ kernel_size, stride, padding, ceil_mode));
2409+ return bridge::AtenFromXlaTensor (std::move (output));
23922410}
23932411
23942412std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool3d_with_indices (
@@ -2404,12 +2422,14 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool3d_with_indices(
24042422 dilation,
24052423 ceil_mode);
24062424 }
2407- XLA_ASSIGN_OR_THROW (XLATensorPtr xla_self, bridge::GetXlaTensor (self));
2408- auto outputs = tensor_methods::max_pool_nd (
2409- xla_self, /* spatial_dim_count=*/ 3 , XlaHelpers::I64List (kernel_size),
2410- XlaHelpers::I64List (stride), XlaHelpers::I64List (padding), ceil_mode);
2411- return std::make_tuple (bridge::AtenFromXlaTensor (std::get<0 >(outputs)),
2412- bridge::AtenFromXlaTensor (std::get<1 >(outputs)));
2425+ XLA_ASSIGN_OR_THROW (absl_nonnull XLATensorPtr xla_self,
2426+ bridge::GetXlaTensor (self));
2427+ std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
2428+ XLA_ASSIGN_OR_THROW (output, tensor_methods::max_pool_nd (
2429+ xla_self, /* spatial_dim_count=*/ 3 ,
2430+ kernel_size, stride, padding, ceil_mode));
2431+ return std::make_tuple (bridge::AtenFromXlaTensor (std::get<0 >(output)),
2432+ bridge::AtenFromXlaTensor (std::get<1 >(output)));
24132433}
24142434
24152435at::Tensor XLANativeFunctions::max_unpool2d (const at::Tensor& self,
0 commit comments