
Commit 12f35bf

Improve error handling and error messages for *_pool*d.
1 parent 2329746 · commit 12f35bf

File tree: 5 files changed, +277 −167 lines changed

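The changes below share one pattern: the tensor_methods pooling entry points (max_pool_nd, avg_pool_nd, and their backward variants) now report failures through a status value that callers unwrap with XLA_ASSIGN_OR_THROW, instead of returning a bare XLATensorPtr. The following is a minimal, self-contained sketch of that pattern using absl::StatusOr; the function name and the validation it performs are illustrative assumptions, not the actual torch_xla code.

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"

// Hypothetical stand-in for the kind of validation a pooling method can now
// perform: a bad argument becomes a Status with a readable message instead of
// a failure deep inside the lowering.
absl::StatusOr<std::vector<int64_t>> PoolOutputDims(
    const std::vector<int64_t>& input_dims, int spatial_dim_count,
    const std::vector<int64_t>& kernel_size) {
  if (kernel_size.size() != static_cast<size_t>(spatial_dim_count)) {
    return absl::InvalidArgumentError(
        absl::StrFormat("max_pool%dd: expected %d kernel dimensions, got %d",
                        spatial_dim_count, spatial_dim_count,
                        kernel_size.size()));
  }
  // Trivial "shape math" just so the success path returns something.
  std::vector<int64_t> out = input_dims;
  for (int i = 0; i < spatial_dim_count; ++i) {
    out[out.size() - spatial_dim_count + i] /= kernel_size[i];
  }
  return out;
}

int main() {
  // Error path: the caller sees the message; XLA_ASSIGN_OR_THROW, as its name
  // suggests, rethrows such errors at the call site.
  auto bad = PoolOutputDims({1, 3, 8, 8}, /*spatial_dim_count=*/2, {2});
  if (!bad.ok()) std::cout << bad.status().message() << "\n";
  // Success path: unwrap the value and keep going.
  auto ok = PoolOutputDims({1, 3, 8, 8}, /*spatial_dim_count=*/2, {2, 2});
  if (ok.ok()) std::cout << "output rank: " << ok->size() << "\n";
  return 0;
}

With this shape, an invalid argument surfaces as a descriptive status at the call site rather than a crash later on, which is the behavior the error-message improvements in this commit rely on.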

test/cpp/test_tensor.cpp

Lines changed: 45 additions & 34 deletions

@@ -4,6 +4,7 @@
 #include <limits>
 #include <vector>
 
+#include "absl/base/nullability.h"
 #include "test/cpp/cpp_test_util.h"
 #include "test/cpp/torch_xla_test.h"
 #include "torch/csrc/autograd/variable.h"
@@ -297,14 +298,18 @@ TEST_F(TensorTest, TestMaxPool2D) {
         /*padding=*/{padding, padding}, /*dilation=*/{1, 1},
         /*ceil_mode=*/false);
     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-      XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+      XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                           XLATensor::Create(input, device));
-      auto dev_output = tensor_methods::max_pool_nd(
-          dev_input,
-          /*spatial_dim_count=*/2,
-          /*kernel_size=*/{kernel_size, kernel_size},
-          /*stride=*/{stride, stride},
-          /*padding=*/{padding, padding}, /*ceil_mode=*/false);
+      std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr>
+          dev_output;
+      XLA_ASSIGN_OR_THROW(
+          dev_output,
+          tensor_methods::max_pool_nd(
+              dev_input,
+              /*spatial_dim_count=*/2,
+              /*kernel_size=*/{kernel_size, kernel_size},
+              /*stride=*/{stride, stride},
+              /*padding=*/{padding, padding}, /*ceil_mode=*/false));
       AllClose(output, std::get<0>(dev_output));
     });
   }
@@ -322,15 +327,18 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
         /*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
         /*ceil_mode=*/false);
     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-      XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+      XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                           XLATensor::Create(input, device));
-      auto dev_output = tensor_methods::max_pool_nd(
-          dev_input,
-          /*spatial_dim_count=*/2,
-          /*kernel_size=*/{kernel_size, kernel_size + 1},
-          /*stride=*/{stride, stride + 1},
-          /*padding=*/{padding, padding + 1},
-          /*ceil_mode=*/false);
+      std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr>
+          dev_output;
+      XLA_ASSIGN_OR_THROW(dev_output,
+                          tensor_methods::max_pool_nd(
+                              dev_input,
+                              /*spatial_dim_count=*/2,
+                              /*kernel_size=*/{kernel_size, kernel_size + 1},
+                              /*stride=*/{stride, stride + 1},
+                              /*padding=*/{padding, padding + 1},
+                              /*ceil_mode=*/false));
       AllClose(output, std::get<0>(dev_output));
     });
   }
@@ -351,16 +359,17 @@ TEST_F(TensorTest, TestAvgPool2D) {
         /*ceil_mode=*/false, count_include_pad,
         /*divisor_override=*/std::nullopt);
     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-      XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+      XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                           XLATensor::Create(input, device));
-      XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
-          dev_input,
-          /*spatial_dim_count=*/2,
-          /*kernel_size=*/{kernel_size, kernel_size},
-          /*stride=*/{stride, stride},
-          /*padding=*/{padding, padding},
-          /*ceil_mode=*/false, count_include_pad,
-          /*divisor_override=*/std::nullopt);
+      XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_output,
+                          tensor_methods::avg_pool_nd(
+                              dev_input,
+                              /*spatial_dim_count=*/2,
+                              /*kernel_size=*/{kernel_size, kernel_size},
+                              /*stride=*/{stride, stride},
+                              /*padding=*/{padding, padding},
+                              /*ceil_mode=*/false, count_include_pad,
+                              /*divisor_override=*/std::nullopt));
       AllClose(output, dev_output);
     });
   }
@@ -382,17 +391,19 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
         /*count_include_pad=*/count_include_pad,
         /*divisor_override=*/std::nullopt);
     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-      XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+      XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                           XLATensor::Create(input, device));
-      XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
-          dev_input,
-          /*spatial_dim_count=*/2,
-          /*kernel_size=*/{kernel_size, kernel_size + 1},
-          /*stride=*/{stride, stride + 1},
-          /*padding=*/{padding, padding + 1},
-          /*ceil_mode=*/false,
-          /*count_include_pad=*/count_include_pad,
-          /*divisor_override=*/std::nullopt);
+      XLA_ASSIGN_OR_THROW(
+          absl_nonnull XLATensorPtr dev_output,
+          tensor_methods::avg_pool_nd(
+              dev_input,
+              /*spatial_dim_count=*/2,
+              /*kernel_size=*/{kernel_size, kernel_size + 1},
+              /*stride=*/{stride, stride + 1},
+              /*padding=*/{padding, padding + 1},
+              /*ceil_mode=*/false,
+              /*count_include_pad=*/count_include_pad,
+              /*divisor_override=*/std::nullopt));
       AllClose(output, dev_output);
     });
   }
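The test file also starts including "absl/base/nullability.h" and annotating XLATensorPtr values as absl_nonnull. Below is a small sketch of what that annotation expresses, with stand-in types; XLATensor and XLATensorPtr are stubs here, only assumed to resemble the real torch_xla aliases.

#include <memory>
#include <tuple>

#include "absl/base/nullability.h"

// Stand-ins for the real torch_xla types, only to make the snippet build.
struct XLATensor {};
using XLATensorPtr = std::shared_ptr<XLATensor>;

int main() {
  // Mirrors the declarations in the diff above: absl_nonnull documents (and
  // lets nullability-aware tooling verify) that these smart pointers are
  // never null, so the values can be used without an extra null check.
  absl_nonnull XLATensorPtr dev_input = std::make_shared<XLATensor>();
  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> dev_output{
      dev_input, dev_input};
  (void)dev_output;
  return 0;
}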

torch_xla/csrc/aten_autograd_ops.cpp

Lines changed: 30 additions & 22 deletions

@@ -192,11 +192,13 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
     return std::get<0>(results);
   }
   ctx->save_for_backward({self});
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return bridge::AtenFromXlaTensor(std::get<0>(outputs));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::get<0>(output));
 }
 
 torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
@@ -220,13 +222,15 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
                                        padding, dilation,
                                        ceil_mode, indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output_0,
                       bridge::GetXlaTensor(grad_output[0]));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
+                          kernel_size, stride, padding, ceil_mode));
+  grad = bridge::AtenFromXlaTensor(std::move(output));
 
   torch::Tensor undef;
   torch::autograd::variable_list grad_inputs = {grad, undef, undef,
@@ -239,24 +243,28 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
                                  torch::IntArrayRef stride,
                                  torch::IntArrayRef padding,
                                  torch::IntArrayRef dilation, bool ceil_mode) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return bridge::AtenFromXlaTensor(std::get<0>(outputs));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::get<0>(output));
 }
 
 torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
                                   torch::IntArrayRef kernel_size,
                                   torch::IntArrayRef stride,
                                   torch::IntArrayRef padding, bool ceil_mode) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/2,
+                          kernel_size, stride, padding, ceil_mode));
+  auto grad = bridge::AtenFromXlaTensor(std::move(output));
   return grad;
 }
 
torch_xla/csrc/aten_xla_type.cpp

Lines changed: 66 additions & 46 deletions

@@ -1182,11 +1182,14 @@ at::Tensor XLANativeFunctions::avg_pool2d(
     at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
     std::optional<int64_t> divisor_override) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode,
-      count_include_pad, divisor_override));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd(xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode,
+                                  count_include_pad, divisor_override));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::avg_pool2d_backward(
@@ -1203,25 +1206,31 @@ at::Tensor XLANativeFunctions::avg_pool2d_backward(
                                          count_include_pad,
                                          divisor_override);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd_backward(
+          xla_grad_output, xla_self, /*spatial_dim_count=*/2, kernel_size,
+          stride, padding, ceil_mode, count_include_pad));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::avg_pool3d(
     const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride,
     at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
     std::optional<int64_t> divisor_override) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode,
-      count_include_pad, divisor_override));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd(xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode,
+                                  count_include_pad, divisor_override));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::avg_pool3d_backward(
@@ -1238,13 +1247,16 @@ at::Tensor XLANativeFunctions::avg_pool3d_backward(
                                          count_include_pad,
                                          divisor_override);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd_backward(
+          xla_grad_output, xla_self, /*spatial_dim_count=*/3, kernel_size,
+          stride, padding, ceil_mode, count_include_pad));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
@@ -2327,12 +2339,14 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool2d_with_indices(
                                                 dilation,
                                                 ceil_mode);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
-                         bridge::AtenFromXlaTensor(std::get<1>(outputs)));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode));
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(output)),
+                         bridge::AtenFromXlaTensor(std::get<1>(output)));
 }
 
 at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward(
@@ -2350,13 +2364,15 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward(
                                                          padding, dilation,
                                                          ceil_mode, indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/2,
+                          kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::max_pool3d(
@@ -2382,13 +2398,15 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward(
                                                          padding, dilation,
                                                          ceil_mode, indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/3,
+                          kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool3d_with_indices(
@@ -2404,12 +2422,14 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool3d_with_indices(
                                                 dilation,
                                                 ceil_mode);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
-                         bridge::AtenFromXlaTensor(std::get<1>(outputs)));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode));
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(output)),
+                         bridge::AtenFromXlaTensor(std::get<1>(output)));
 }
 
 at::Tensor XLANativeFunctions::max_unpool2d(const at::Tensor& self,
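A second thread running through these call sites: kernel_size, stride, and padding are now passed along as at::IntArrayRef instead of being converted with XlaHelpers::I64List at every caller. The sketch below shows the presumed shape of that change, with the conversion (and the argument checks behind the improved error messages) centralized in the callee; the namespace and helper names are illustrative, not the actual tensor_methods signatures.

#include <cstdint>
#include <vector>

#include <c10/util/ArrayRef.h>

namespace pool_sketch {

// One shared conversion instead of XlaHelpers::I64List at every call site.
std::vector<int64_t> ToI64Vector(c10::ArrayRef<int64_t> values) {
  return std::vector<int64_t>(values.begin(), values.end());
}

// Illustrative callee shape: taking the ArrayRef itself lets the pooling
// method own both the conversion and the argument checks whose messages
// this commit improves, instead of every caller repeating them.
bool HasUsableRank(c10::ArrayRef<int64_t> arg, int spatial_dim_count) {
  return arg.size() == static_cast<size_t>(spatial_dim_count) ||
         arg.size() == 1;  // a single value may be broadcast to all dims
}

}  // namespace pool_sketch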
