-
Notifications
You must be signed in to change notification settings - Fork 3
/
TensorUtils.h
154 lines (137 loc) · 5.21 KB
/
TensorUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#pragma once
#include <ATen/Tensor.h>
#include <ATen/TensorGeometry.h>
#include <ATen/Utils.h>

#include <utility>
// These functions are NOT in Utils.h, because this file has a dep on Tensor.h
namespace at {
// The following are utility functions for checking that arguments
// make sense. These are particularly useful for native functions,
// which do NO argument checking by default.
struct CAFFE2_API TensorArg {
Tensor tensor;
const char* name;
int pos; // 1-indexed
TensorArg(Tensor tensor, const char* name, int pos)
: tensor(std::move(tensor)), name(name), pos(pos) {}
const Tensor* operator->() const { return &tensor; }
const Tensor& operator*() const { return tensor; }
};
// Same idea as TensorArg, but carries only a TensorGeometry together with
// the argument's name and (1-indexed) position.
//
// NB: constructing one from a TensorArg converts the Tensor to a
// TensorGeometry; per the note on the undefined convention in this header,
// that conversion will blow up for an undefined tensor.
struct CAFFE2_API TensorGeometryArg {
  TensorGeometry tensor;
  const char* name; // argument name; only the pointer is stored
  int pos; // 1-indexed

  // Deliberately implicit, so a TensorArg can be passed wherever a
  // TensorGeometryArg is expected.
  /* implicit */ TensorGeometryArg(TensorArg arg)
      : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}

  // `tensor` is a by-value sink parameter, so move it into the member
  // instead of copying it a second time (mirrors TensorArg's constructor).
  TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
      : tensor(std::move(tensor)), name(name), pos(pos) {}

  // Read-only access to the wrapped geometry, by pointer or by reference.
  const TensorGeometry* operator->() const { return &tensor; }
  const TensorGeometry& operator*() const { return tensor; }
};
// A string describing which function did checks on its input
// arguments. It is passed as the first parameter (`c`) of the checkXxx
// functions declared below (e.g. checkDim). The alias is a plain
// `const char*`, so it neither owns nor copies the characters.
// TODO: Consider generalizing this into a call stack.
using CheckedFrom = const char*;
// The undefined convention: singular operators assume their arguments
// are defined, but functions which take multiple tensors will
// implicitly filter out undefined tensors (to make it easier to perform
// tests which should apply if the tensor is defined, and should not
// otherwise.)
//
// NB: This means that the n-ary operators take lists of TensorArg,
// not TensorGeometryArg, because the Tensor to TensorGeometry
// conversion will blow up if you have undefined tensors.
// Stream-insertion operator for a TensorGeometryArg (taken by value);
// returns a std::ostream reference.
// NOTE(review): the output format is defined out of line; presumably it is
// how the check functions name the offending argument in error messages -
// confirm against the implementation.
CAFFE2_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t);
// Checks the dimensionality of `t` against `dim` on behalf of caller `c`.
// NOTE(review): implementation is not in this header; presumably reports an
// error when t's number of dimensions differs from `dim` - confirm.
CAFFE2_API void checkDim(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t dim);
// Checks the dimensionality of `t` against a range on behalf of caller `c`.
// NB: this is an inclusive-exclusive range, i.e. [dim_start, dim_end).
// NOTE(review): presumably reports an error when t's number of dimensions
// falls outside that range - confirm against the implementation.
CAFFE2_API void checkDimRange(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t dim_start,
int64_t dim_end);
// Two-argument dimensionality check on behalf of caller `c`.
// NOTE(review): presumably reports an error when `t1` and `t2` do not have
// the same number of dimensions - confirm against the implementation.
CAFFE2_API void checkSameDim(
CheckedFrom c,
const TensorGeometryArg& t1,
const TensorGeometryArg& t2);
// Contiguity checks on behalf of caller `c`: one for a single geometry
// argument, one for a list of tensor arguments.
// NOTE(review): presumably each reports an error for a non-contiguous
// argument; per the undefined convention described in this header, the list
// form is expected to skip undefined tensors - confirm.
CAFFE2_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
CAFFE2_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
// Size check against a full list of sizes, on behalf of caller `c`.
// NOTE(review): presumably reports an error unless t's sizes equal `sizes`
// - confirm against the implementation.
CAFFE2_API void checkSize(
CheckedFrom c,
const TensorGeometryArg& t,
IntArrayRef sizes);
// Size check for a single dimension, on behalf of caller `c`.
// NOTE(review): presumably reports an error unless the size of `t` along
// dimension `dim` equals `size` - confirm against the implementation.
CAFFE2_API void checkSize(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t dim,
int64_t size);
// Element-count check for a single argument, on behalf of caller `c`.
// NOTE(review): presumably reports an error unless `t` has exactly `numel`
// elements - confirm against the implementation.
CAFFE2_API void checkNumel(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t numel);
// Element-count check for a pair of arguments.
// NOTE(review): presumably reports an error unless `t1` and `t2` have the
// same number of elements - confirm.
CAFFE2_API void checkSameNumel(
CheckedFrom c,
const TensorGeometryArg& t1,
const TensorGeometryArg& t2);
// List form of checkSameNumel. Takes TensorArg (not TensorGeometryArg); per
// the undefined convention described in this header, n-ary checks are
// expected to skip undefined tensors - TODO confirm.
CAFFE2_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
// Scalar-type check against a single ScalarType, on behalf of caller `c`.
// NOTE(review): presumably reports an error unless t's scalar type is `s`
// - confirm against the implementation.
CAFFE2_API void checkScalarType(
CheckedFrom c,
const TensorArg& t,
ScalarType s);
// Scalar-type check against a list of ScalarTypes.
// NOTE(review): presumably reports an error unless t's scalar type is one
// of the entries of `l` - confirm.
CAFFE2_API void checkScalarTypes(
CheckedFrom c,
const TensorArg& t,
at::ArrayRef<ScalarType> l);
// GPU-placement check for a pair of tensor arguments, on behalf of caller
// `c`.
// NOTE(review): presumably reports an error unless `t1` and `t2` live on
// the same GPU; how CPU tensors are treated is not visible here - confirm.
CAFFE2_API void checkSameGPU(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
// List form of checkSameGPU; per the undefined convention described in this
// header, it is expected to skip undefined tensors - TODO confirm.
CAFFE2_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
// Type check for a pair of tensor arguments, on behalf of caller `c`.
// NOTE(review): exactly which properties make up the "type" being compared
// is decided by the implementation, not visible here - confirm.
CAFFE2_API void checkSameType(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
// List form of checkSameType; per the undefined convention described in
// this header, it is expected to skip undefined tensors - TODO confirm.
CAFFE2_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
// Size check for a pair of tensor arguments, on behalf of caller `c`.
// NOTE(review): presumably reports an error unless `t1` and `t2` have
// identical sizes - confirm against the implementation.
CAFFE2_API void checkSameSize(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
// Definedness checks on behalf of caller `c`: one for a single tensor
// argument, one for a list of them.
// NOTE(review): presumably each reports an error when a tensor is
// undefined - confirm against the implementation.
CAFFE2_API void checkDefined(CheckedFrom c, const TensorArg& t);
CAFFE2_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
// The checks below take plain Tensor / ArrayRef<Tensor> rather than
// TensorArg, so they carry no argument name or position.
// FixMe: does TensorArg slow things down?
// NOTE(review): presumably reports an error unless every tensor in `t` has
// backend `backend` - confirm against the implementation.
CAFFE2_API void checkBackend(
CheckedFrom c,
at::ArrayRef<Tensor> t,
at::Backend backend);
// NOTE(review): presumably reports an error unless every tensor in
// `tensors` is on a device of type `device_type` - confirm.
CAFFE2_API void checkDeviceType(
CheckedFrom c,
at::ArrayRef<Tensor> tensors,
at::DeviceType device_type);
// Layout checks, for a single tensor and for a list of tensors.
// NOTE(review): presumably each reports an error unless the tensor(s) have
// layout `layout` - confirm.
CAFFE2_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
CAFFE2_API void checkLayout(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Layout layout);
// Methods for getting data_ptr if tensor is defined. Overloads are provided
// for a plain Tensor and for a TensorArg; both return an untyped `void*`.
// NOTE(review): the result for an undefined tensor is not visible here -
// presumably nullptr; confirm against the implementation.
CAFFE2_API void* maybe_data_ptr(const Tensor& tensor);
CAFFE2_API void* maybe_data_ptr(const TensorArg& tensor);
// Returns whether the tensor geometry represented by `sizes` and `strides`
// is contiguous.
// Although we cache is_contiguous in tensor now, this is still useful
// because it allows checking if a particular geometry is contiguous without
// explicitly constructing a tensor, e.g., when you want to choose a kernel
// strategy based on whether a subgeometry is contiguous.
CAFFE2_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
// Correspond to THCUNN_check_dim_size/THNN_check_dim_size
// Unlike the checkXxx functions in this header, this takes a plain Tensor
// and no CheckedFrom.
// NOTE(review): presumably reports an error unless `tensor` has `dim`
// dimensions and its size along dimension `dim_size` equals `size` (the
// parameter named `dim_size` looks like a dimension index, not a size) -
// confirm against the implementation.
CAFFE2_API void check_dim_size(
const Tensor& tensor,
int64_t dim,
int64_t dim_size,
int64_t size);
// Geometry helpers declared here and defined out of line.
namespace detail {
// Returns a stride vector computed from `sizes` alone.
// NOTE(review): presumably the strides of a contiguous tensor of that shape
// - confirm against the definition.
CAFFE2_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
// Returns a byte count computed from a geometry and a per-element size.
// NOTE(review): presumably the storage needed to address every element of a
// tensor with these `sizes`/`strides` and `itemsize`-byte elements -
// confirm.
CAFFE2_API size_t
computeStorageNbytes(IntArrayRef sizes, IntArrayRef strides, size_t itemsize);
// Returns an optional stride vector derived from an old shape/stride pair
// and a new shape.
// NOTE(review): presumably the strides that let `newshape` view the same
// memory without copying, or an empty optional when no such strides exist
// - confirm against the definition.
CAFFE2_API c10::optional<std::vector<int64_t>> computeStride(
IntArrayRef oldshape,
IntArrayRef oldstride,
IntArrayRef newshape);
} // namespace detail
} // namespace at