Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit a924abd

Browse files
authored
Remove torch::nested_tensor::NestedTensor (#184)
1 parent 6384059 commit a924abd

File tree

7 files changed

+90
-131
lines changed

7 files changed

+90
-131
lines changed

nestedtensor/csrc/creation.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ NestedNode<c10::IValue> py_to_nested_tensor(const py::object& py_obj) {
177177
if (THPVariable_Check(py_obj.ptr())) {
178178
at::Tensor tensor = THPVariable_Unpack(py_obj.ptr());
179179
if (is_nested_tensor_impl(tensor)) {
180-
auto tensor_data_structure = get_nested_tensor(tensor).get_structure();
180+
auto tensor_data_structure = get_nested_tensor_impl(tensor)->get_structure();
181181
return map([](at::Tensor a) { return c10::IValue(a); }, tensor_data_structure);
182182
}
183183
}
@@ -193,7 +193,7 @@ NestedNode<c10::IValue> py_to_nested_tensor(const py::object& py_obj) {
193193
}
194194
}
195195

196-
NestedTensor _as_nested_tensor(py::sequence list) {
196+
NestedTensorImpl _as_nested_tensor(py::sequence list) {
197197
NestedNode<c10::IValue> ivalue_structure = py_to_nested_tensor(list);
198198
auto fn = [](c10::IValue a, bool result) { return result && a.isTensor(); };
199199
bool all_same =
@@ -208,7 +208,7 @@ NestedTensor _as_nested_tensor(py::sequence list) {
208208
_verify_variables(*first, structure, true);
209209
}
210210
}
211-
return NestedTensor(std::move(structure));
211+
return NestedTensorImpl(std::move(structure));
212212
}
213213

214214
at::Tensor nested_tensor_impl(py::sequence list) {

nestedtensor/csrc/functions.cpp

+9-11
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ Tensor NestedTensor_max_pool2d(
5454
IntArrayRef dilation,
5555
bool ceil_mode) {
5656
auto self_impl = get_nested_tensor_impl(self);
57-
auto nt = self_impl->_data;
5857
auto tensor_node = get_nested_tensor_structure(self);
5958

6059
if (is_tensor_shape(self)) {
@@ -66,8 +65,7 @@ Tensor NestedTensor_max_pool2d(
6665
auto res = at::max_pool2d(
6766
at::stack(tensors), kernel_size, stride, padding, dilation, ceil_mode);
6867

69-
return NestedTensorImpl(
70-
torch::nested_tensor::NestedTensor(std::move(res)))
68+
return NestedTensorImpl(std::move(res))
7169
.to_nested_tensor(self_impl->nested_dim() - 1);
7270
}
7371

@@ -198,16 +196,16 @@ Tensor NestedTensor_layer_norm(
198196
TORCH_CHECK(
199197
normalized_shape.size() == 1,
200198
"Currently only singleton tuples of integers supported for layer_norm.");
201-
auto input_data = get_nested_tensor(input);
199+
auto input_data = get_nested_tensor_impl(input);
202200
TORCH_CHECK(
203-
input_data.sizes()[input.dim() - 1],
201+
input_data->opt_sizes()[input.dim() - 1],
204202
"Cannot normalize across irregular dimension ",
205203
std::to_string(input.dim() - 1));
206204
return wrap_tensor_node(map(
207205
[normalized_shape, &weight, &bias, eps](const at::Tensor t) {
208206
return at::layer_norm(t, normalized_shape, weight, bias, eps, true);
209207
},
210-
input_data.get_structure()));
208+
input_data->get_structure()));
211209
}
212210

213211
Tensor& NestedTensor_add_(Tensor& self, const Tensor& other, Scalar alpha) {
@@ -225,7 +223,7 @@ Tensor& NestedTensor_add_(Tensor& self, const Tensor& other, Scalar alpha) {
225223
}
226224

227225
Tensor NestedTensor_all(const Tensor& self) {
228-
auto self_impl = get_nested_tensor_impl(self)->_data;
226+
auto self_impl = get_nested_tensor_impl(self);
229227
if (self.numel() == 0) {
230228
// XXX: self.options doesn't work here because
231229
// we don't want a Tensor backed by a NestedTensor
@@ -235,7 +233,7 @@ Tensor NestedTensor_all(const Tensor& self) {
235233
}
236234
auto map_all = flatten(
237235
map([](at::Tensor tensor) { return tensor.all(); },
238-
self_impl.get_structure()));
236+
self_impl->get_structure()));
239237
at::Tensor gathered = at::empty(
240238
{static_cast<int64_t>(map_all.size())}, at::kBool); //, self.options());
241239
for (size_t i = 0; i < map_all.size(); i++) {
@@ -245,7 +243,7 @@ Tensor NestedTensor_all(const Tensor& self) {
245243
}
246244

247245
Tensor NestedTensor_any(const Tensor& self) {
248-
auto self_impl = get_nested_tensor_impl(self)->_data;
246+
auto self_impl = get_nested_tensor_impl(self);
249247
if (self.numel() == 0) {
250248
// XXX: self.options doesn't work here because
251249
// we don't want a Tensor backed by a NestedTensor
@@ -255,7 +253,7 @@ Tensor NestedTensor_any(const Tensor& self) {
255253
}
256254
auto map_any = flatten(
257255
map([](at::Tensor tensor) { return tensor.any(); },
258-
self_impl.get_structure()));
256+
self_impl->get_structure()));
259257
at::Tensor gathered = at::empty(
260258
{static_cast<int64_t>(map_any.size())}, at::kBool); //, self.options());
261259
for (size_t i = 0; i < map_any.size(); i++) {
@@ -271,7 +269,7 @@ Tensor NestedTensor__log_softmax(
271269
auto self_impl = get_nested_tensor_impl(input_);
272270
return at::detail::make_tensor<NestedTensorImpl>(
273271
map([&](Tensor a) { return at::_log_softmax(a, dim_, half_to_float); },
274-
self_impl->_data.get_structure()));
272+
self_impl->get_structure()));
275273
}
276274

277275
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {

nestedtensor/csrc/nested_tensor_impl.cpp

+41-53
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
#include <ATen/ATen.h>
77
#include <nestedtensor/csrc/utils/nested_node_functions.h>
88

9-
namespace torch {
10-
namespace nested_tensor {
9+
namespace at {
10+
11+
using namespace torch::nested_tensor;
1112

1213
int64_t num_memory(c10::List<int64_t> size, c10::List<int64_t> stride) {
1314
// 0-dim Tensors have torch.Size of .size() 0, but carry 1 memory.
@@ -47,7 +48,7 @@ std::vector<c10::optional<int64_t>> construct_size(const SizeNode& size_node) {
4748
return result;
4849
}
4950

50-
std::vector<c10::optional<int64_t>> NestedTensor::sizes() const {
51+
std::vector<c10::optional<int64_t>> NestedTensorImpl::opt_sizes() const {
5152
return construct_size(
5253
map([](at::Tensor tensor) { return c10::List<int64_t>(tensor.sizes()); },
5354
get_structure()));
@@ -85,11 +86,28 @@ TensorNode _unbind_tensors(TensorNode structure) {
8586
return TensorNode(std::move(result_nodes));
8687
}
8788

88-
NestedTensor::NestedTensor(TensorNode&& structure)
89-
: _structure(structure),
89+
NestedTensorImpl::NestedTensorImpl(TensorNode structure)
90+
: TensorImpl(
91+
c10::DispatchKeySet(NestedTensorKey),
92+
get_first_leaf(structure) ? get_first_leaf(structure)->dtype()
93+
: at::ones({}).dtype(),
94+
get_first_leaf(structure) ? get_first_leaf(structure)->device()
95+
: at::ones({}).device()),
96+
_structure(structure),
9097
_first_variable(
9198
get_first_leaf(_structure) ? *get_first_leaf(_structure)
92-
: at::ones({})) {}
99+
: at::ones({})),
100+
_nested_size(map(
101+
[](at::Tensor tensor) { return c10::List<int64_t>(tensor.sizes()); },
102+
_structure)) {
103+
for (auto opt_int : construct_size(_nested_size)) {
104+
if (opt_int) {
105+
_sizes.push_back(*opt_int);
106+
} else {
107+
break;
108+
}
109+
}
110+
}
93111

94112
inline TensorNode _squeeze_nested_dim(TensorNode structure, int64_t dim) {
95113
if (dim == 0) {
@@ -98,13 +116,6 @@ inline TensorNode _squeeze_nested_dim(TensorNode structure, int64_t dim) {
98116
return TensorNode(_squeeze_nested_dim(structure, dim - 1));
99117
}
100118

101-
} // namespace nested_tensor
102-
} // namespace torch
103-
104-
namespace at {
105-
106-
using namespace torch::nested_tensor;
107-
108119
at::Tensor _to_tensor(TensorNode node) {
109120
// TODO: Recursive stacking is expensive.
110121
if (node.is_leaf()) {
@@ -123,7 +134,7 @@ at::Tensor _to_tensor(TensorNode node) {
123134
at::Tensor NestedTensorImpl::to_tensor() {
124135
// TODO: Not necessarily a view because of stack and reshape.
125136
std::vector<int64_t> new_size;
126-
for (const auto& si : _data.sizes()) {
137+
for (const auto& si : opt_sizes()) {
127138
if (!si) {
128139
// TODO: This assumes we'll extend to_tensor to also work with int64_t at
129140
// this level.
@@ -151,7 +162,7 @@ Tensor NestedTensorImpl::to_nested_tensor(c10::optional<int64_t> dim__) {
151162
}
152163
return at::detail::make_tensor<NestedTensorImpl>(NestedTensorImpl(std::move(unbound)));
153164
}
154-
return at::detail::make_tensor<NestedTensorImpl>(_data);
165+
return at::detail::make_tensor<NestedTensorImpl>(_structure);
155166
}
156167

157168

@@ -166,33 +177,18 @@ at::NestedTensorImpl* get_nested_tensor_impl(const at::Tensor tensor) {
166177
return static_cast<at::NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
167178
}
168179

169-
torch::nested_tensor::NestedTensor get_nested_tensor(
170-
const at::Tensor tensor) {
171-
return get_nested_tensor_impl(tensor)->_data;
172-
}
173-
174180
torch::nested_tensor::TensorNode get_nested_tensor_structure(
175181
const at::Tensor tensor) {
176182
return get_nested_tensor_impl(tensor)->get_structure();
177183
}
178184

179-
at::Tensor wrap_nested_tensor(
180-
torch::nested_tensor::NestedTensor&& result) {
181-
return at::detail::make_tensor<NestedTensorImpl>(std::move(result));
182-
}
183-
184185
at::Tensor wrap_tensor_node(
185186
torch::nested_tensor::TensorNode&& result) {
186-
return at::detail::make_tensor<NestedTensorImpl>(
187-
torch::nested_tensor::NestedTensor(std::move(result)));
188-
}
189-
190-
IntArrayRef NestedTensorImpl::sizes() const {
191-
return IntArrayRef(_sizes);
187+
return at::detail::make_tensor<NestedTensorImpl>(result);
192188
}
193189

194190
int64_t NestedTensorImpl::size(int64_t dim) const {
195-
std::vector<c10::optional<int64_t>> size = _data.sizes();
191+
std::vector<c10::optional<int64_t>> size = opt_sizes();
196192
if (size[dim]) {
197193
return *(size[dim]);
198194
}
@@ -238,15 +234,14 @@ Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_) {
238234
for (Tensor child : unbound) {
239235
auto ci = NestedTensor_to_tensor(child, dim - 1);
240236
if (is_nested_tensor_impl(ci)) {
241-
auto s = get_nested_tensor(ci).get_structure();
237+
auto s = get_nested_tensor_impl(ci)->get_structure();
242238
result.push_back(TensorNode(std::move(s)));
243239
} else {
244240
// TODO: If it's a NestedTensor instance get the structure
245241
result.push_back(TensorNode(std::move(ci)));
246242
}
247243
}
248-
return at::detail::make_tensor<at::NestedTensorImpl>(
249-
NestedTensor(TensorNode(std::move(result))));
244+
return at::detail::make_tensor<at::NestedTensorImpl>(TensorNode(std::move(result)));
250245
}
251246

252247
bool NestedTensor_is_pinned(const Tensor& self) {
@@ -283,14 +278,14 @@ std::vector<at::Tensor> NestedTensor_unbind(const at::Tensor &self, int64_t dim)
283278
std::vector<at::Tensor> result;
284279
for (size_t i = 0; i < unbound.size(); i++) {
285280
TensorNode tmp = TensorNode(std::move(unbound[i]));
286-
result.push_back(at::detail::make_tensor<NestedTensorImpl>(NestedTensor(std::move(tmp))));
281+
result.push_back(at::detail::make_tensor<NestedTensorImpl>(std::move(tmp)));
287282
}
288283
return result;
289284
}
290285
}
291286
std::vector<at::Tensor> unbound_thp;
292287
for (auto child : node.unbind()) {
293-
unbound_thp.push_back(at::detail::make_tensor<NestedTensorImpl>(NestedTensor(std::move(child))));
288+
unbound_thp.push_back(at::detail::make_tensor<NestedTensorImpl>(std::move(child)));
294289
}
295290
if (dim == 0) {
296291
return unbound_thp;
@@ -308,7 +303,7 @@ std::vector<at::Tensor> NestedTensor_unbind(const at::Tensor &self, int64_t dim)
308303
std::vector<at::Tensor> result;
309304
for (size_t i = 0; i < unbound.size(); i++) {
310305
result.push_back(at::detail::make_tensor<NestedTensorImpl>(
311-
NestedTensor(TensorNode(std::move(unbound[i])))));
306+
TensorNode(std::move(unbound[i]))));
312307
}
313308
return result;
314309
}
@@ -319,10 +314,8 @@ Tensor NestedTensor_select(const Tensor& self, int64_t dim, int64_t index) {
319314
if (dim == 0) {
320315
TORCH_CHECK_INDEX(false, "select() only supports dim == 0 for now.");
321316
}
322-
TensorNode tn = get_nested_tensor(self).get_structure().unbind()[index];
323-
torch::nested_tensor::NestedTensor nt = torch::nested_tensor::NestedTensor(
324-
std::move(tn));
325-
return at::detail::make_tensor<NestedTensorImpl>(std::move(nt));
317+
TensorNode tn = get_nested_tensor_impl(self)->get_structure().unbind()[index];
318+
return at::detail::make_tensor<NestedTensorImpl>(std::move(tn));
326319
}
327320

328321
Tensor NestedTensor_clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
@@ -331,7 +324,7 @@ Tensor NestedTensor_clone(const Tensor& src, c10::optional<c10::MemoryFormat> op
331324
map([&optional_memory_format](Tensor a) {
332325
return at::clone(a, optional_memory_format);
333326
},
334-
self_impl->_data.get_structure()));
327+
self_impl->get_structure()));
335328
}
336329

337330
Tensor& NestedTensor_copy_(Tensor& self, const Tensor& src, bool non_blocking) {
@@ -353,7 +346,7 @@ Tensor _NestedTensor_squeeze_(Tensor self, c10::optional<int64_t> dim_) {
353346
// TODO: First dimension is always ignored.
354347
// We could decide to return a Tensor if the 0th
355348
// dimension can be squeezed.
356-
auto init_sizes = self_impl->_data.sizes();
349+
auto init_sizes = self_impl->opt_sizes();
357350
for (size_t i = 0; i < init_sizes.size() - 1; i++) {
358351
int64_t index = init_sizes.size() - i - 1;
359352
c10::optional<int64_t> s = init_sizes[index];
@@ -366,8 +359,8 @@ Tensor _NestedTensor_squeeze_(Tensor self, c10::optional<int64_t> dim_) {
366359
int64_t dim = at::maybe_wrap_dim(*dim_, self.dim());
367360
TORCH_CHECK(dim > 0, "Cannot squeeze first dimension.");
368361
TORCH_CHECK(
369-
((get_nested_tensor_impl(self)->_data.sizes()[dim]) &&
370-
((*(get_nested_tensor_impl(self)->_data.sizes()[dim])) == 1)),
362+
((get_nested_tensor_impl(self)->opt_sizes()[dim]) &&
363+
((*(get_nested_tensor_impl(self)->opt_sizes()[dim])) == 1)),
371364
"Given dimension is either undefined or not a singleton.");
372365
if (dim < get_nested_tensor_impl(self)->nested_dim()) {
373366
return wrap_tensor_node(
@@ -380,16 +373,12 @@ Tensor _NestedTensor_squeeze_(Tensor self, c10::optional<int64_t> dim_) {
380373
}
381374

382375
Tensor& NestedTensor_squeeze_(Tensor& self) {
383-
auto new_tensor = _NestedTensor_squeeze_(self, c10::nullopt);
384-
auto self_impl = get_nested_tensor_impl(self);
385-
self_impl->_data = get_nested_tensor_impl(new_tensor)->_data;
376+
self = _NestedTensor_squeeze_(self, c10::nullopt);
386377
return self;
387378
}
388379

389380
Tensor& NestedTensor_squeeze__dim(Tensor& self, int64_t dim) {
390-
auto new_tensor = _NestedTensor_squeeze_(self, dim);
391-
auto self_impl = get_nested_tensor_impl(self);
392-
self_impl->_data = get_nested_tensor_impl(new_tensor)->_data;
381+
self = _NestedTensor_squeeze_(self, dim);
393382
return self;
394383
}
395384

@@ -415,5 +404,4 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1_PreAutograd, m) {
415404
m.impl_UNBOXED("unbind.int", NestedTensor_unbind);
416405
m.impl_UNBOXED("select.int", NestedTensor_select);
417406
}
418-
419407
}

0 commit comments

Comments
 (0)