ExpandUtils.cpp
#include <ATen/ExpandUtils.h>
namespace at {
// NOTE: are_expandable does a similar check; please keep the two in sync if changes are needed.
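// infer_size computes the broadcast result shape of two size lists: dimensions
// are aligned from the trailing end, dimensions missing in the shorter list are
// treated as 1, and each aligned pair must either match or contain a 1.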
std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
  size_t dimsA = a.size();
  size_t dimsB = b.size();
  size_t ndim = dimsA > dimsB ? dimsA : dimsB;
  std::vector<int64_t> expandedSizes(ndim);

  // Use ptrdiff_t to ensure signed comparison.
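  // Walk the dimensions from the trailing end; a dimension that exists in only
  // one of the two inputs is treated as size 1 on the other side.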
  for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
    int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // 1s map to the other size (even 0).
    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
  }

  return expandedSizes;
}
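
// inferExpandGeometry computes the sizes and strides of a view that expands a
// tensor with the given tensor_sizes/tensor_strides to the target sizes. A
// target entry of -1 keeps the tensor's existing size, and broadcast dimensions
// receive a stride of 0, so the expansion never copies data.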
std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  int64_t ndim = sizes.size();
  int64_t tensor_dim = tensor_sizes.size();

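  // A zero-dimensional (scalar) tensor expands to any target shape: every
  // stride is 0, so all positions of the view alias the single element.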
  if (tensor_dim == 0) {
    std::vector<int64_t> expandedStrides(ndim, 0);
    return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
        sizes.vec(), expandedStrides);
  }

  std::vector<int64_t> expandedSizes(ndim);
  std::vector<int64_t> expandedStrides(ndim);

  // create a new geometry for the tensors
  for (int64_t i = ndim - 1; i >= 0; --i) {
    int64_t offset = ndim - 1 - i;
    int64_t dim = tensor_dim - 1 - offset;
    int64_t size = (dim >= 0) ? tensor_sizes[dim] : 1;
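    // Dimensions the tensor does not have get a stride as if the existing
    // layout were repeated contiguously: size * stride of the dimension to the
    // right, which is already filled in since we iterate right to left.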
    int64_t stride = (dim >= 0) ? tensor_strides[dim]
                                : expandedSizes[i + 1] * expandedStrides[i + 1];
    int64_t targetSize = sizes[i];
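    // A target size of -1 means "keep the tensor's existing size here"; it is
    // only meaningful for dimensions the tensor actually has.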
    if (targetSize == -1) {
      TORCH_CHECK(
          dim >= 0,
          "The expanded size of the tensor (",
          targetSize,
          ") isn't allowed in a leading, non-existing dimension ",
          i);
      targetSize = size;
    }
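    // Only size-1 dimensions may be broadcast to a larger target size; they are
    // expanded by setting the stride to 0, so every index reads the same element.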
    if (size != targetSize) {
      TORCH_CHECK(
          size == 1,
          "The expanded size of the tensor (",
          targetSize,
          ") must match the existing size (",
          size,
          ") at non-singleton dimension ",
          i,
          ". Target sizes: ",
          sizes,
          ". Tensor sizes: ",
          tensor_sizes);
      size = targetSize;
      stride = 0;
    }

    expandedSizes[i] = size;
    expandedStrides[i] = stride;
  }

  return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
      expandedSizes, expandedStrides);
}
} // namespace at
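
// Illustrative usage (a sketch, not exercised in this file): broadcasting a
// {3, 1} tensor against a {1, 4} tensor.
//
//   std::vector<int64_t> out = at::infer_size({3, 1}, {1, 4});   // {3, 4}
//   auto geometry = at::inferExpandGeometry(
//       /*tensor_sizes=*/{3, 1}, /*tensor_strides=*/{1, 1}, /*sizes=*/out);
//   // std::get<0>(geometry) == {3, 4}, std::get<1>(geometry) == {1, 0}:
//   // the size-1 dimension expands with stride 0, so no data is copied.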