forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTensorGeometry.cpp
40 lines (34 loc) · 1 KB
/
TensorGeometry.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <ATen/TensorGeometry.h>
#include <limits>
#include <cstddef>
namespace at {
// See TensorGeometry.h on why this is useful now that we cache is_contiguous.
// Returns true when `sizes`/`strides` describe a densely packed layout in
// which the last dimension is innermost: walking from the last dimension
// outward, each stride must equal the product of all inner sizes. Also
// returns true as soon as any dimension has size 0 (an empty tensor is
// reported as contiguous regardless of its strides).
//
// Dimensions of size 1 are skipped in the stride comparison; only index 0 is
// ever valid along them, so their stride cannot affect addressing.
//
// NOTE(review): `strides` is indexed with the same indices as `sizes` and is
// not length-checked here; callers are expected to pass equal-length arrays.
template <typename T>
bool _geometry_is_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides) {
  assert(!overflows<std::int64_t>(sizes.size()));
  const auto ndim = static_cast<std::int64_t>(sizes.size());
  T running_stride = 1;
  bool still_contiguous = true;
  // Walk from the innermost dimension outward. Scanning continues even after
  // a stride mismatch because a later zero-sized dimension still makes the
  // result true.
  for (auto d = ndim; d > 0; --d) {
    const auto idx = d - 1;
    const T& extent = sizes[idx];
    if (extent == 0) {
      return true;
    }
    if (!still_contiguous) {
      continue;
    }
    if (extent != 1 && strides[idx] != running_stride) {
      still_contiguous = false;
    }
    running_stride *= extent;
  }
  return still_contiguous;
}
// Non-symbolic entry point: runs the shared contiguity walk on concrete
// integer sizes and strides (IntArrayRef, i.e. ArrayRef<int64_t>).
bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) {
  return at::_geometry_is_contiguous<std::int64_t>(sizes, strides);
}
// Reports whether this geometry is contiguous. When numel_ is zero the
// answer is true without inspecting strides. Otherwise sizes_/strides_ are
// checked element-wise as c10::SymInt values by the shared helper.
bool TensorGeometry::is_contiguous() const {
  return numel_ == 0 ||
      at::_geometry_is_contiguous<c10::SymInt>(sizes_, strides_);
}
} // namespace at