1
+ #include < ATen/ATen.h>
1
2
#include < ATen/WrapDimUtils.h>
2
3
#include < ATen/core/op_registration/op_registration.h>
3
4
#include < nestedtensor/csrc/nested_tensor_impl.h>
5
+ #include < nestedtensor/csrc/utils/nested_node_functions.h>
4
6
#include < torch/csrc/jit/runtime/operator.h>
5
7
#include < torch/library.h>
6
- #include < ATen/ATen.h>
7
- #include < nestedtensor/csrc/utils/nested_node_functions.h>
8
8
9
9
namespace at {
10
10
@@ -146,7 +146,6 @@ at::Tensor NestedTensorImpl::to_tensor() {
146
146
return _to_tensor (get_structure ());
147
147
}
148
148
149
-
150
149
Tensor NestedTensorImpl::to_nested_tensor (c10::optional<int64_t > dim__) {
151
150
int64_t dim_ = 0 ;
152
151
if (dim__) {
@@ -160,12 +159,11 @@ Tensor NestedTensorImpl::to_nested_tensor(c10::optional<int64_t> dim__) {
160
159
for (int64_t i = 0 ; i < (dim - nested_dim ()); i++) {
161
160
unbound = _unbind_tensors (unbound);
162
161
}
163
- return at::detail::make_tensor<NestedTensorImpl>( NestedTensorImpl ( std::move (unbound) ));
162
+ return wrap_tensor_node ( std::move (unbound));
164
163
}
165
- return at::detail::make_tensor<NestedTensorImpl> (_structure);
164
+ return wrap_tensor_node ( std::move (_structure) );
166
165
}
167
166
168
-
169
167
bool is_nested_tensor_impl (const at::Tensor tensor) {
170
168
return tensor.unsafeGetTensorImpl ()->key_set ().has (at::NestedTensorKey);
171
169
}
@@ -177,22 +175,32 @@ at::NestedTensorImpl* get_nested_tensor_impl(const at::Tensor tensor) {
177
175
return static_cast <at::NestedTensorImpl*>(tensor.unsafeGetTensorImpl ());
178
176
}
179
177
180
- torch::nested_tensor::TensorNode get_nested_tensor_structure (
181
- const at::Tensor tensor) {
178
+ TensorNode get_nested_tensor_structure (const at::Tensor tensor) {
182
179
return get_nested_tensor_impl (tensor)->get_structure ();
183
180
}
184
181
185
- at::Tensor wrap_tensor_node (
186
- torch::nested_tensor::TensorNode&& result) {
182
+ at::Tensor wrap_tensor_node (TensorNode&& result) {
183
+ if (result.is_leaf ()) {
184
+ return result.payload ();
185
+ }
187
186
return at::detail::make_tensor<NestedTensorImpl>(result);
188
187
}
189
188
189
+ std::vector<at::Tensor> wrap_tensor_node (std::vector<TensorNode> input) {
190
+ std::vector<at::Tensor> result;
191
+ for (size_t i = 0 ; i < input.size (); i++) {
192
+ result.push_back (wrap_tensor_node (std::move (input[i])));
193
+ }
194
+ return result;
195
+ }
196
+
190
197
int64_t NestedTensorImpl::size (int64_t dim) const {
191
198
std::vector<c10::optional<int64_t >> size = opt_sizes ();
192
199
if (size[dim]) {
193
200
return *(size[dim]);
194
201
}
195
- throw std::runtime_error (" NestedTensor size at dim is not Tensor shape compliant." );
202
+ throw std::runtime_error (
203
+ " NestedTensor size at dim is not Tensor shape compliant." );
196
204
}
197
205
198
206
IntArrayRef NestedTensorImpl::strides () const {
@@ -208,7 +216,7 @@ Tensor NestedTensor_contiguous(const Tensor& self, MemoryFormat memory_format) {
208
216
" preserve memory format is unsupported by the contiguous operator" );
209
217
return wrap_tensor_node (
210
218
map ([](at::Tensor tensor) { return tensor.contiguous (); },
211
- get_nested_tensor_impl (self)-> get_structure ( )));
219
+ get_nested_tensor_structure (self)));
212
220
}
213
221
214
222
Tensor NestedTensor_to_tensor (Tensor tensor, c10::optional<int64_t > dim_) {
@@ -241,71 +249,38 @@ Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_) {
241
249
result.push_back (TensorNode (std::move (ci)));
242
250
}
243
251
}
244
- return at::detail::make_tensor<at::NestedTensorImpl> (TensorNode (std::move (result)));
252
+ return wrap_tensor_node (TensorNode (std::move (result)));
245
253
}
246
254
247
255
bool NestedTensor_is_pinned (const Tensor& self) {
248
256
return get_nested_tensor_impl (self)->is_pinned ();
249
257
}
250
258
251
- std::vector<at::Tensor> NestedTensor_unbind (const at::Tensor &self, int64_t dim) {
259
+ std::vector<at::Tensor> NestedTensor_unbind (
260
+ const at::Tensor& self,
261
+ int64_t dim) {
252
262
auto _data = get_nested_tensor_impl (self);
253
263
dim = at::maybe_wrap_dim (dim, _data->dim ());
254
264
auto node = _data->get_structure ();
255
- auto nested_dim = _data->nested_dim ();
256
- if (nested_dim == 1 ) {
257
- if (dim == 0 ) {
258
- std::vector<at::Tensor> result;
259
- for (const auto & child : node.unbind ()) {
260
- result.push_back (child.payload ());
261
- }
262
- return result;
263
- } else {
264
- int64_t dim_max_size = 0 ;
265
- for (const auto & child : node.unbind ()) {
266
- int64_t dim_size = child.payload ().size (dim - 1 );
267
- dim_max_size = dim_max_size > dim_size ? dim_max_size : dim_size;
268
- }
269
- std::vector<std::vector<TensorNode>> unbound;
270
- unbound.resize (dim_max_size);
271
- for (const auto & child : node.unbind ()) {
272
- std::vector<at::Tensor> unbound_tensors =
273
- at::unbind (child.payload (), dim - 1 );
274
- for (size_t i = 0 ; i < unbound_tensors.size (); i++) {
275
- unbound[i].push_back (TensorNode (std::move (unbound_tensors[i])));
276
- }
277
- }
278
- std::vector<at::Tensor> result;
279
- for (size_t i = 0 ; i < unbound.size (); i++) {
280
- TensorNode tmp = TensorNode (std::move (unbound[i]));
281
- result.push_back (at::detail::make_tensor<NestedTensorImpl>(std::move (tmp)));
282
- }
283
- return result;
284
- }
285
- }
286
- std::vector<at::Tensor> unbound_thp;
287
- for (auto child : node.unbind ()) {
288
- unbound_thp.push_back (at::detail::make_tensor<NestedTensorImpl>(std::move (child)));
289
- }
290
265
if (dim == 0 ) {
291
- return unbound_thp ;
266
+ return wrap_tensor_node (node. unbind ()) ;
292
267
}
293
268
std::vector<std::vector<TensorNode>> unbound;
294
- for (size_t i = 0 ; i < unbound_thp.size (); i++) {
295
- std::vector<at::Tensor> tmp = unbound_thp[i].unbind (dim - 1 );
269
+ for (auto child : node.unbind ()) {
270
+ std::vector<at::Tensor> tmp =
271
+ at::unbind (wrap_tensor_node (std::move (child)), dim - 1 );
296
272
for (size_t j = 0 ; j < tmp.size (); j++) {
297
- if (unbound.size () >= j ) {
273
+ if (j >= unbound.size ()) {
298
274
unbound.resize (j + 1 );
299
275
}
300
276
unbound[j].push_back (TensorNode (std::move (tmp[j])));
301
277
}
302
278
}
303
- std::vector<at::Tensor > result;
279
+ std::vector<TensorNode > result;
304
280
for (size_t i = 0 ; i < unbound.size (); i++) {
305
- result.push_back (at::detail::make_tensor<NestedTensorImpl>(
306
- TensorNode (std::move (unbound[i]))));
281
+ result.push_back (TensorNode (std::move (unbound[i])));
307
282
}
308
- return result;
283
+ return wrap_tensor_node ( result) ;
309
284
}
310
285
311
286
Tensor NestedTensor_select (const Tensor& self, int64_t dim, int64_t index) {
@@ -314,17 +289,19 @@ Tensor NestedTensor_select(const Tensor& self, int64_t dim, int64_t index) {
314
289
if (dim == 0 ) {
315
290
TORCH_CHECK_INDEX (false , " select() only supports dim == 0 for now." );
316
291
}
317
- TensorNode tn = get_nested_tensor_impl (self)->get_structure ().unbind ()[index ];
318
- return at::detail::make_tensor<NestedTensorImpl>(std::move (tn));
292
+ auto children = get_nested_tensor_structure (self).unbind ();
293
+ auto child = children[index ];
294
+ return wrap_tensor_node (std::move (child));
319
295
}
320
296
321
- Tensor NestedTensor_clone (const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
322
- auto self_impl = get_nested_tensor_impl (src);
323
- return at::detail::make_tensor<NestedTensorImpl>(
324
- map ([&optional_memory_format](Tensor a) {
325
- return at::clone (a, optional_memory_format);
326
- },
327
- self_impl->get_structure ()));
297
+ Tensor NestedTensor_clone (
298
+ const Tensor& src,
299
+ c10::optional<c10::MemoryFormat> optional_memory_format) {
300
+ return wrap_tensor_node (map (
301
+ [&optional_memory_format](Tensor a) {
302
+ return at::clone (a, optional_memory_format);
303
+ },
304
+ get_nested_tensor_structure (src)));
328
305
}
329
306
330
307
Tensor& NestedTensor_copy_ (Tensor& self, const Tensor& src, bool non_blocking) {
@@ -404,4 +381,4 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1_PreAutograd, m) {
404
381
m.impl_UNBOXED (" unbind.int" , NestedTensor_unbind);
405
382
m.impl_UNBOXED (" select.int" , NestedTensor_select);
406
383
}
407
- }
384
+ } // namespace at
0 commit comments