6
6
#include < ATen/ATen.h>
7
7
#include < nestedtensor/csrc/utils/nested_node_functions.h>
8
8
9
- namespace torch {
10
- namespace nested_tensor {
9
+ namespace at {
10
+
11
+ using namespace torch ::nested_tensor;
11
12
12
13
int64_t num_memory (c10::List<int64_t > size, c10::List<int64_t > stride) {
13
14
// 0-dim Tensors have torch.Size of .size() 0, but carry 1 memory.
@@ -47,7 +48,7 @@ std::vector<c10::optional<int64_t>> construct_size(const SizeNode& size_node) {
47
48
return result;
48
49
}
49
50
50
- std::vector<c10::optional<int64_t >> NestedTensor::sizes () const {
51
+ std::vector<c10::optional<int64_t >> NestedTensorImpl::opt_sizes () const {
51
52
return construct_size (
52
53
map ([](at::Tensor tensor) { return c10::List<int64_t >(tensor.sizes ()); },
53
54
get_structure ()));
@@ -85,11 +86,28 @@ TensorNode _unbind_tensors(TensorNode structure) {
85
86
return TensorNode (std::move (result_nodes));
86
87
}
87
88
88
- NestedTensor::NestedTensor (TensorNode&& structure)
89
- : _structure(structure),
89
+ NestedTensorImpl::NestedTensorImpl (TensorNode structure)
90
+ : TensorImpl(
91
+ c10::DispatchKeySet (NestedTensorKey),
92
+ get_first_leaf(structure) ? get_first_leaf(structure)->dtype()
93
+ : at::ones({}).dtype(),
94
+ get_first_leaf(structure) ? get_first_leaf(structure)->device()
95
+ : at::ones({}).device()),
96
+ _structure(structure),
90
97
_first_variable(
91
98
get_first_leaf (_structure) ? *get_first_leaf(_structure)
92
- : at::ones({})) {}
99
+ : at::ones({})),
100
+ _nested_size(map(
101
+ [](at::Tensor tensor) { return c10::List<int64_t >(tensor.sizes ()); },
102
+ _structure)) {
103
+ for (auto opt_int : construct_size (_nested_size)) {
104
+ if (opt_int) {
105
+ _sizes.push_back (*opt_int);
106
+ } else {
107
+ break ;
108
+ }
109
+ }
110
+ }
93
111
94
112
inline TensorNode _squeeze_nested_dim (TensorNode structure, int64_t dim) {
95
113
if (dim == 0 ) {
@@ -98,13 +116,6 @@ inline TensorNode _squeeze_nested_dim(TensorNode structure, int64_t dim) {
98
116
return TensorNode (_squeeze_nested_dim (structure, dim - 1 ));
99
117
}
100
118
101
- } // namespace nested_tensor
102
- } // namespace torch
103
-
104
- namespace at {
105
-
106
- using namespace torch ::nested_tensor;
107
-
108
119
at::Tensor _to_tensor (TensorNode node) {
109
120
// TODO: Recursive stacking is expensive.
110
121
if (node.is_leaf ()) {
@@ -123,7 +134,7 @@ at::Tensor _to_tensor(TensorNode node) {
123
134
at::Tensor NestedTensorImpl::to_tensor () {
124
135
// TODO: Not necessarily a view because of stack and reshape.
125
136
std::vector<int64_t > new_size;
126
- for (const auto & si : _data. sizes ()) {
137
+ for (const auto & si : opt_sizes ()) {
127
138
if (!si) {
128
139
// TODO: This assumes we'll extend to_tensor to also work with int64_t at
129
140
// this level.
@@ -151,7 +162,7 @@ Tensor NestedTensorImpl::to_nested_tensor(c10::optional<int64_t> dim__) {
151
162
}
152
163
return at::detail::make_tensor<NestedTensorImpl>(NestedTensorImpl (std::move (unbound)));
153
164
}
154
- return at::detail::make_tensor<NestedTensorImpl>(_data );
165
+ return at::detail::make_tensor<NestedTensorImpl>(_structure );
155
166
}
156
167
157
168
@@ -166,33 +177,18 @@ at::NestedTensorImpl* get_nested_tensor_impl(const at::Tensor tensor) {
166
177
return static_cast <at::NestedTensorImpl*>(tensor.unsafeGetTensorImpl ());
167
178
}
168
179
169
- torch::nested_tensor::NestedTensor get_nested_tensor (
170
- const at::Tensor tensor) {
171
- return get_nested_tensor_impl (tensor)->_data ;
172
- }
173
-
174
180
torch::nested_tensor::TensorNode get_nested_tensor_structure (
175
181
const at::Tensor tensor) {
176
182
return get_nested_tensor_impl (tensor)->get_structure ();
177
183
}
178
184
179
- at::Tensor wrap_nested_tensor (
180
- torch::nested_tensor::NestedTensor&& result) {
181
- return at::detail::make_tensor<NestedTensorImpl>(std::move (result));
182
- }
183
-
184
185
at::Tensor wrap_tensor_node (
185
186
torch::nested_tensor::TensorNode&& result) {
186
- return at::detail::make_tensor<NestedTensorImpl>(
187
- torch::nested_tensor::NestedTensor (std::move (result)));
188
- }
189
-
190
- IntArrayRef NestedTensorImpl::sizes () const {
191
- return IntArrayRef (_sizes);
187
+ return at::detail::make_tensor<NestedTensorImpl>(result);
192
188
}
193
189
194
190
int64_t NestedTensorImpl::size (int64_t dim) const {
195
- std::vector<c10::optional<int64_t >> size = _data. sizes ();
191
+ std::vector<c10::optional<int64_t >> size = opt_sizes ();
196
192
if (size[dim]) {
197
193
return *(size[dim]);
198
194
}
@@ -238,15 +234,14 @@ Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_) {
238
234
for (Tensor child : unbound) {
239
235
auto ci = NestedTensor_to_tensor (child, dim - 1 );
240
236
if (is_nested_tensor_impl (ci)) {
241
- auto s = get_nested_tensor (ci). get_structure ();
237
+ auto s = get_nested_tensor_impl (ci)-> get_structure ();
242
238
result.push_back (TensorNode (std::move (s)));
243
239
} else {
244
240
// TODO: If it's a NestedTensor instance get the structure
245
241
result.push_back (TensorNode (std::move (ci)));
246
242
}
247
243
}
248
- return at::detail::make_tensor<at::NestedTensorImpl>(
249
- NestedTensor (TensorNode (std::move (result))));
244
+ return at::detail::make_tensor<at::NestedTensorImpl>(TensorNode (std::move (result)));
250
245
}
251
246
252
247
bool NestedTensor_is_pinned (const Tensor& self) {
@@ -283,14 +278,14 @@ std::vector<at::Tensor> NestedTensor_unbind(const at::Tensor &self, int64_t dim)
283
278
std::vector<at::Tensor> result;
284
279
for (size_t i = 0 ; i < unbound.size (); i++) {
285
280
TensorNode tmp = TensorNode (std::move (unbound[i]));
286
- result.push_back (at::detail::make_tensor<NestedTensorImpl>(NestedTensor ( std::move (tmp) )));
281
+ result.push_back (at::detail::make_tensor<NestedTensorImpl>(std::move (tmp)));
287
282
}
288
283
return result;
289
284
}
290
285
}
291
286
std::vector<at::Tensor> unbound_thp;
292
287
for (auto child : node.unbind ()) {
293
- unbound_thp.push_back (at::detail::make_tensor<NestedTensorImpl>(NestedTensor ( std::move (child) )));
288
+ unbound_thp.push_back (at::detail::make_tensor<NestedTensorImpl>(std::move (child)));
294
289
}
295
290
if (dim == 0 ) {
296
291
return unbound_thp;
@@ -308,7 +303,7 @@ std::vector<at::Tensor> NestedTensor_unbind(const at::Tensor &self, int64_t dim)
308
303
std::vector<at::Tensor> result;
309
304
for (size_t i = 0 ; i < unbound.size (); i++) {
310
305
result.push_back (at::detail::make_tensor<NestedTensorImpl>(
311
- NestedTensor ( TensorNode (std::move (unbound[i]) ))));
306
+ TensorNode (std::move (unbound[i]))));
312
307
}
313
308
return result;
314
309
}
@@ -319,10 +314,8 @@ Tensor NestedTensor_select(const Tensor& self, int64_t dim, int64_t index) {
319
314
if (dim == 0 ) {
320
315
TORCH_CHECK_INDEX (false , " select() only supports dim == 0 for now." );
321
316
}
322
- TensorNode tn = get_nested_tensor (self).get_structure ().unbind ()[index ];
323
- torch::nested_tensor::NestedTensor nt = torch::nested_tensor::NestedTensor (
324
- std::move (tn));
325
- return at::detail::make_tensor<NestedTensorImpl>(std::move (nt));
317
+ TensorNode tn = get_nested_tensor_impl (self)->get_structure ().unbind ()[index ];
318
+ return at::detail::make_tensor<NestedTensorImpl>(std::move (tn));
326
319
}
327
320
328
321
Tensor NestedTensor_clone (const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
@@ -331,7 +324,7 @@ Tensor NestedTensor_clone(const Tensor& src, c10::optional<c10::MemoryFormat> op
331
324
map ([&optional_memory_format](Tensor a) {
332
325
return at::clone (a, optional_memory_format);
333
326
},
334
- self_impl->_data . get_structure ()));
327
+ self_impl->get_structure ()));
335
328
}
336
329
337
330
Tensor& NestedTensor_copy_ (Tensor& self, const Tensor& src, bool non_blocking) {
@@ -353,7 +346,7 @@ Tensor _NestedTensor_squeeze_(Tensor self, c10::optional<int64_t> dim_) {
353
346
// TODO: First dimension is always ignored.
354
347
// We could decide to return a Tensor if the 0th
355
348
// dimension can be squeezed.
356
- auto init_sizes = self_impl->_data . sizes ();
349
+ auto init_sizes = self_impl->opt_sizes ();
357
350
for (size_t i = 0 ; i < init_sizes.size () - 1 ; i++) {
358
351
int64_t index = init_sizes.size () - i - 1 ;
359
352
c10::optional<int64_t > s = init_sizes[index ];
@@ -366,8 +359,8 @@ Tensor _NestedTensor_squeeze_(Tensor self, c10::optional<int64_t> dim_) {
366
359
int64_t dim = at::maybe_wrap_dim (*dim_, self.dim ());
367
360
TORCH_CHECK (dim > 0 , " Cannot squeeze first dimension." );
368
361
TORCH_CHECK (
369
- ((get_nested_tensor_impl (self)->_data . sizes ()[dim]) &&
370
- ((*(get_nested_tensor_impl (self)->_data . sizes ()[dim])) == 1 )),
362
+ ((get_nested_tensor_impl (self)->opt_sizes ()[dim]) &&
363
+ ((*(get_nested_tensor_impl (self)->opt_sizes ()[dim])) == 1 )),
371
364
" Given dimension is either undefined or not a singleton." );
372
365
if (dim < get_nested_tensor_impl (self)->nested_dim ()) {
373
366
return wrap_tensor_node (
@@ -380,16 +373,12 @@ Tensor _NestedTensor_squeeze_(Tensor self, c10::optional<int64_t> dim_) {
380
373
}
381
374
382
375
Tensor& NestedTensor_squeeze_ (Tensor& self) {
383
- auto new_tensor = _NestedTensor_squeeze_ (self, c10::nullopt);
384
- auto self_impl = get_nested_tensor_impl (self);
385
- self_impl->_data = get_nested_tensor_impl (new_tensor)->_data ;
376
+ self = _NestedTensor_squeeze_ (self, c10::nullopt);
386
377
return self;
387
378
}
388
379
389
380
Tensor& NestedTensor_squeeze__dim (Tensor& self, int64_t dim) {
390
- auto new_tensor = _NestedTensor_squeeze_ (self, dim);
391
- auto self_impl = get_nested_tensor_impl (self);
392
- self_impl->_data = get_nested_tensor_impl (new_tensor)->_data ;
381
+ self = _NestedTensor_squeeze_ (self, dim);
393
382
return self;
394
383
}
395
384
@@ -415,5 +404,4 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1_PreAutograd, m) {
415
404
m.impl_UNBOXED (" unbind.int" , NestedTensor_unbind);
416
405
m.impl_UNBOXED (" select.int" , NestedTensor_select);
417
406
}
418
-
419
407
}
0 commit comments