-
Notifications
You must be signed in to change notification settings - Fork 3
/
InferSize.h
52 lines (46 loc) · 1.67 KB
/
InferSize.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <vector>
namespace at {

/// Resolves a requested `shape` against an element count, filling in the one
/// dimension that may be left unspecified as -1.
///
/// Every entry of `shape` must be either >= 0 or exactly -1, and at most one
/// entry may be -1. The product of the explicit (non -1) entries must either
/// equal `numel`, or - when a -1 entry is present - be positive and divide
/// `numel` evenly, in which case the -1 entry becomes `numel / product`.
///
/// @param shape  Requested sizes, possibly containing a single -1.
/// @param numel  Number of elements the resulting shape must describe.
/// @return A copy of `shape` in which the -1 entry (if any) has been replaced
///         by the inferred size.
/// @throws std::runtime_error if more than one entry is -1, or if the shape
///         cannot describe exactly `numel` elements. The latter includes the
///         case where the product of the explicit sizes does not fit in
///         int64_t.
///
/// Entries below -1 are reported through AT_ERROR, and the ambiguous
/// "-1 together with a zero-sized dimension" case through TORCH_CHECK; the
/// exception type raised by those macros is defined elsewhere.
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
  constexpr int64_t kMaxSize = std::numeric_limits<int64_t>::max();

  auto res = shape.vec();
  int64_t newsize = 1;
  // Set when the running product of explicit sizes would exceed int64_t.
  // The decision is deferred until the whole shape has been scanned: a later
  // zero-sized dimension makes the true product 0 (no overflow after all),
  // and later malformed entries must still be reported in dimension order.
  bool overflowed = false;
  auto infer_dim = c10::optional<int64_t>();

  for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
    const int64_t extent = shape[dim];
    if (extent == -1) {
      if (infer_dim) {
        throw std::runtime_error("only one dimension can be inferred");
      }
      infer_dim = dim;
    } else if (extent == 0) {
      // The product is exactly zero from here on, whatever came before.
      newsize = 0;
      overflowed = false;
    } else if (extent > 0) {
      // Check before multiplying: signed overflow is undefined behaviour, and
      // a wrapped product could otherwise compare equal to `numel` by
      // accident. When newsize is already 0 the comparison is false and the
      // product simply stays 0.
      if (newsize > kMaxSize / extent) {
        overflowed = true;
      } else {
        newsize *= extent;
      }
    } else {
      AT_ERROR("invalid shape dimension ", extent);
    }
  }

  // An overflowed product is larger than any representable `numel`, so such
  // a shape can never match; it falls through to the generic error below.
  if (!overflowed &&
      (numel == newsize ||
       (infer_dim && newsize > 0 && numel % newsize == 0))) {
    if (infer_dim) {
      // Reaching this point with newsize == 0 means numel == 0 as well, so
      // any value would satisfy the -1 entry. Follow NumPy semantics and
      // refuse to guess, but say so explicitly: `view` is commonly used to
      // flatten and unflatten dimensions, and without this message it is
      // confusing that
      //   empty_tensor.view( 0, 0)
      // works while
      //   empty_tensor.view(-1, 0)
      // does not.
      TORCH_CHECK(newsize != 0, "cannot reshape tensor of 0 elements into shape ",
                  shape, " because the unspecified dimension size -1 can be any "
                  "value and is ambiguous");
      res[*infer_dim] = numel / newsize;
    }
    return res;
  }

  std::ostringstream ss;
  ss << "shape '" << shape << "' is invalid for input of size " << numel;
  throw std::runtime_error(ss.str());
}

}  // namespace at