-
Notifications
You must be signed in to change notification settings - Fork 1
/
NamedTensor.cpp
143 lines (121 loc) · 4.14 KB
/
NamedTensor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <ATen/core/NamedTensor.h>
#include <ATen/core/EnableNamedTensor.h>
#ifdef BUILD_NAMEDTENSOR
#include <ATen/core/Tensor.h>
#include <c10/util/C++17.h>
namespace at {
// Reports whether at least one stored dimension name is something other than
// a wildcard. An empty name list therefore reports false.
bool NamedTensorMeta::has_names() const {
  const auto is_real_name = [](const Dimname& dimname) {
    return dimname.type() != NameType::WILDCARD;
  };
  return std::any_of(names_.begin(), names_.end(), is_real_name);
}
// Per-thread switch consulted by NamesMode::is_enabled(). Every thread starts
// with names mode enabled.
thread_local bool NamesMode_enabled = true;

// Returns the calling thread's current names-mode flag.
bool NamesMode::is_enabled() {
  const bool enabled_on_this_thread = NamesMode_enabled;
  return enabled_on_this_thread;
}

// Updates the calling thread's names-mode flag. Other threads keep their own
// value because the flag is thread_local.
void NamesMode::set_enabled(bool enabled) {
  NamesMode_enabled = enabled;
}
// Tensor-level wrapper: forwards to the TensorImpl overload in `impl` and
// returns the same tensor so calls can be chained.
Tensor& internal_set_names_inplace(Tensor& tensor, optional<DimnameList> names) {
  auto* const tensor_impl = tensor.unsafeGetTensorImpl();
  impl::internal_set_names_inplace(tensor_impl, names);
  return tensor;
}
// Rvalue overload of the Tensor-level wrapper: forwards `names` as an rvalue
// to the TensorImpl-level setter. `validate_names` selects whether the names
// are checked against the tensor first. Returns `tensor` for chaining.
Tensor& internal_set_names_inplace(Tensor& tensor, std::vector<Dimname>&& names, bool validate_names) {
  auto* const tensor_impl = tensor.unsafeGetTensorImpl();
  impl::internal_set_names_inplace(tensor_impl, std::move(names), validate_names);
  return tensor;
}
// Returns a view of `len` wildcard names, used for tensors that carry no
// explicit names.
// The view points into a function-local static vector holding
// kMaxNamedTensorDim wildcards. The returned DimnameList therefore stays valid
// for the rest of the program, and no allocation happens per call.
// Fails a TORCH_INTERNAL_ASSERT if `len` exceeds kMaxNamedTensorDim.
DimnameList default_names(size_t len) {
  static std::vector<Dimname> all_unnamed(kMaxNamedTensorDim, Dimname::wildcard());
  TORCH_INTERNAL_ASSERT(
      len <= kMaxNamedTensorDim,
      "Only tensors with up to ", kMaxNamedTensorDim, " dims are supported.");
  // data() (unlike &front()) is well-defined even for an empty vector.
  return DimnameList(all_unnamed.data(), len);
}
// Tensor-level entry point for name validation. Delegates to the TensorImpl
// overload in `impl`, which raises via TORCH_CHECK on any mismatch.
void check_names_valid_for(const Tensor& tensor, DimnameList names) {
  impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
}
namespace impl {
// Rejects `names` if any non-wildcard name appears more than once.
// Wildcards may repeat freely. Raises via TORCH_CHECK on the first duplicate.
static void check_unique_names(DimnameList names) {
  // Strategy: compare each element with the ones that come after it.
  // This is O(N^2), but N is small: the only caller (check_names_valid_for)
  // has already checked that N equals the tensor's dim, which is bounded by
  // kMaxNamedTensorDim.
  for (auto it = names.begin(); it != names.end(); ++it) {
    if (it->isWildcard()) {
      continue;
    }
    const auto dup = std::find(it + 1, names.end(), *it);
    // A single check, not a loop: one duplicate is enough to reject.
    TORCH_CHECK(dup == names.end(),
        "Cannot construct a tensor with duplicate names. Got names: ",
        names, ".");
  }
}
// Returns the NamedTensorMeta attached to `impl`; may be nullptr.
// While names mode is disabled on the calling thread this always yields
// nullptr, so any stored metadata is ignored.
static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) {
  return NamesMode::is_enabled()
      ? static_cast<NamedTensorMeta*>(impl->named_tensor_meta())
      : nullptr;
}
// Const counterpart of the overload above. It returns nullptr whenever names
// mode is disabled on the calling thread, and otherwise returns the (possibly
// null) metadata stored on `impl`.
static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) {
  return NamesMode::is_enabled()
      ? static_cast<const NamedTensorMeta*>(impl->named_tensor_meta())
      : nullptr;
}
// Validates that `names` can be attached to `impl`:
//  - the tensor has at most kMaxNamedTensorDim dimensions,
//  - there is exactly one name per dimension, and
//  - no non-wildcard name is repeated.
// Raises via TORCH_CHECK on the first violated condition.
void check_names_valid_for(TensorImpl* impl, DimnameList names) {
  auto ndim = impl->dim();
  TORCH_CHECK(
      ndim <= kMaxNamedTensorDim,
      "Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
      "Attempted to create a tensor with dim ", ndim, " with names ", names);
  // names.size() is unsigned. Cast it to ndim's own type so this is not a
  // mixed signed/unsigned comparison (-Wsign-compare). ndim is a valid
  // dimension count here, so no value is lost.
  TORCH_CHECK(ndim == static_cast<decltype(ndim)>(names.size()),
      "Number of names (", names.size(), ") and "
      "number of dimensions in tensor (", ndim, ") ",
      "do not match. Attempted to create a tensor with names ", names);
  check_unique_names(names);
}
// Sets or clears the dimension names stored on `impl`.
//  - nullopt: drops any NamedTensorMeta, leaving the tensor unnamed.
//  - otherwise: validates the names against `impl` (raising on failure). It
//    then updates the visible metadata in place, or installs new metadata if
//    none is visible.
// NOTE: get_named_tensor_meta reports nullptr while names mode is disabled on
// this thread, so in that state new metadata is installed via
// set_named_tensor_meta instead of updating the existing one.
void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names) {
  if (names) {
    check_names_valid_for(impl, *names);
    if (auto* existing = get_named_tensor_meta(impl)) {
      existing->set_names(*names);
    } else {
      impl->set_named_tensor_meta(c10::guts::make_unique<NamedTensorMeta>(*names));
    }
    return;
  }
  impl->set_named_tensor_meta(nullptr);
}
// Rvalue overload of internal_set_names_inplace.
// When `validate_names` is true, `names` is first checked against `impl`
// (raising on failure). Callers that have already validated can pass false.
// The vector is then forwarded as an rvalue into the tensor's metadata, so it
// can be taken over rather than copied.
void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names) {
  if (validate_names) {
    check_names_valid_for(impl, names);
  }
  auto* meta = get_named_tensor_meta(impl);
  // The caller handed over `names` via `&&`; pass it on as an rvalue. This
  // assumes NamedTensorMeta offers std::vector<Dimname>&& overloads (TODO:
  // confirm in NamedTensor.h). Otherwise the DimnameList overloads are
  // selected, exactly as before.
  if (meta == nullptr) {
    impl->set_named_tensor_meta(c10::guts::make_unique<NamedTensorMeta>(std::move(names)));
  } else {
    meta->set_names(std::move(names));
  }
}
// Returns the names stored on `impl`, or nullopt when no NamedTensorMeta is
// visible (none attached, or names mode disabled on this thread).
optional<DimnameList> get_opt_names(const TensorImpl* impl) {
  if (const auto* meta = get_named_tensor_meta(impl)) {
    return meta->names();
  }
  return nullopt;
}
// Like get_opt_names, but falls back to an all-wildcard list of length
// impl->dim() when the tensor carries no visible names.
DimnameList get_names(const TensorImpl* impl) {
  const auto stored = get_opt_names(impl);
  return stored ? *stored : default_names(impl->dim());
}
bool has_names(const TensorImpl* impl) {
const auto* named_tensor_meta = get_named_tensor_meta(impl);
return named_tensor_meta != nullptr && named_tensor_meta->has_names();
}
} // namespace impl
} // namespace at
#endif