This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit ca5cf05: Clean up unbind (#186)
1 parent 17fbe15

7 files changed, +169 -174 lines


nestedtensor/csrc/UnaryOps.cpp (+4 -4)

@@ -51,7 +51,7 @@ Tensor& NestedTensor_clamp_(Tensor& self, optional<Scalar> min, optional<Scalar>
 }
 
 Tensor NestedTensor_clamp(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
-  return at::detail::make_tensor<NestedTensorImpl>(
+  return wrap_tensor_node(
       map([min, max](at::Tensor tensor) { return at::clamp(tensor, min, max); },
           get_nested_tensor_structure(self)));
 }
@@ -77,7 +77,7 @@ Tensor& NestedTensor_clamp_min_(Tensor& self, Scalar min) {
 }
 
 Tensor NestedTensor_clamp_min(const Tensor& self, Scalar min) {
-  return at::detail::make_tensor<NestedTensorImpl>(
+  return wrap_tensor_node(
       map([min](at::Tensor tensor) { return at::clamp_min(tensor, min); },
           get_nested_tensor_structure(self)));
 }
@@ -99,7 +99,7 @@ Tensor& NestedTensor_clamp_max_(Tensor& self, Scalar min) {
 }
 
 Tensor NestedTensor_clamp_max(const Tensor& self, Scalar min) {
-  return at::detail::make_tensor<NestedTensorImpl>(
+  return wrap_tensor_node(
       map([min](at::Tensor tensor) { return at::clamp_max(tensor, min); },
           get_nested_tensor_structure(self)));
 }
@@ -121,7 +121,7 @@ Tensor& NestedTensor_mvlgamma_(Tensor& self, int64_t p) {
 }
 
 Tensor NestedTensor_mvlgamma(const Tensor& self, int64_t p) {
-  return at::detail::make_tensor<NestedTensorImpl>(
+  return wrap_tensor_node(
       map([p](at::Tensor tensor) { return at::mvlgamma(tensor, p); },
           get_nested_tensor_structure(self)));
 }
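All four ops above now share one shape: pull out the nested TensorNode structure, map the elementwise call over it, and hand the result to wrap_tensor_node. A minimal sketch of that shape, not taken from the diff; the function name and the at::abs call are stand-ins, the helpers are the ones declared in nestedtensor/csrc/nested_tensor_impl.h:

#include <nestedtensor/csrc/nested_tensor_impl.h>

namespace at {

// Sketch of the shared pattern used by the unary ops in this file:
// map a per-tensor op over the TensorNode structure, then re-wrap it.
Tensor NestedTensor_example_unary(const Tensor& self) {
  return wrap_tensor_node(
      map([](at::Tensor tensor) { return at::abs(tensor); },  // stand-in op
          get_nested_tensor_structure(self)));
}

} // namespace at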

nestedtensor/csrc/functions.cpp (+1 -1)

@@ -267,7 +267,7 @@ Tensor NestedTensor__log_softmax(
     const int64_t dim_,
     const bool half_to_float) {
   auto self_impl = get_nested_tensor_impl(input_);
-  return at::detail::make_tensor<NestedTensorImpl>(
+  return wrap_tensor_node(
       map([&](Tensor a) { return at::_log_softmax(a, dim_, half_to_float); },
           self_impl->get_structure()));
 }

nestedtensor/csrc/nested_tensor_impl.cpp (+44 -67)

@@ -1,10 +1,10 @@
+#include <ATen/ATen.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/core/op_registration/op_registration.h>
 #include <nestedtensor/csrc/nested_tensor_impl.h>
+#include <nestedtensor/csrc/utils/nested_node_functions.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/library.h>
-#include <ATen/ATen.h>
-#include <nestedtensor/csrc/utils/nested_node_functions.h>
 
 namespace at {
 
@@ -146,7 +146,6 @@ at::Tensor NestedTensorImpl::to_tensor() {
   return _to_tensor(get_structure());
 }
 
-
 Tensor NestedTensorImpl::to_nested_tensor(c10::optional<int64_t> dim__) {
   int64_t dim_ = 0;
   if (dim__) {
@@ -160,12 +159,11 @@ Tensor NestedTensorImpl::to_nested_tensor(c10::optional<int64_t> dim__) {
     for (int64_t i = 0; i < (dim - nested_dim()); i++) {
       unbound = _unbind_tensors(unbound);
     }
-    return at::detail::make_tensor<NestedTensorImpl>(NestedTensorImpl(std::move(unbound)));
+    return wrap_tensor_node(std::move(unbound));
   }
-  return at::detail::make_tensor<NestedTensorImpl>(_structure);
+  return wrap_tensor_node(std::move(_structure));
 }
 
-
 bool is_nested_tensor_impl(const at::Tensor tensor) {
   return tensor.unsafeGetTensorImpl()->key_set().has(at::NestedTensorKey);
 }
@@ -177,22 +175,32 @@ at::NestedTensorImpl* get_nested_tensor_impl(const at::Tensor tensor) {
   return static_cast<at::NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
 }
 
-torch::nested_tensor::TensorNode get_nested_tensor_structure(
-    const at::Tensor tensor) {
+TensorNode get_nested_tensor_structure(const at::Tensor tensor) {
   return get_nested_tensor_impl(tensor)->get_structure();
 }
 
-at::Tensor wrap_tensor_node(
-    torch::nested_tensor::TensorNode&& result) {
+at::Tensor wrap_tensor_node(TensorNode&& result) {
+  if (result.is_leaf()) {
+    return result.payload();
+  }
   return at::detail::make_tensor<NestedTensorImpl>(result);
 }
 
+std::vector<at::Tensor> wrap_tensor_node(std::vector<TensorNode> input) {
+  std::vector<at::Tensor> result;
+  for (size_t i = 0; i < input.size(); i++) {
+    result.push_back(wrap_tensor_node(std::move(input[i])));
+  }
+  return result;
+}
+
 int64_t NestedTensorImpl::size(int64_t dim) const {
   std::vector<c10::optional<int64_t>> size = opt_sizes();
   if (size[dim]) {
     return *(size[dim]);
   }
-  throw std::runtime_error("NestedTensor size at dim is not Tensor shape compliant.");
+  throw std::runtime_error(
+      "NestedTensor size at dim is not Tensor shape compliant.");
 }
 
 IntArrayRef NestedTensorImpl::strides() const {
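Note on the hunk above: wrap_tensor_node now unwraps leaf nodes into their plain payload Tensor instead of always allocating a NestedTensorImpl, and the new vector overload applies that rule element-wise. That is what lets the unbind and select paths later in this file stay short. Roughly how a caller composes the two, as a sketch rather than code from this commit; unbind_children is a hypothetical name:

#include <nestedtensor/csrc/nested_tensor_impl.h>

namespace at {

// Sketch: top-level unbind expressed via the new overloads. unbind() on the
// structure yields one TensorNode per child; the vector overload returns a
// plain Tensor for each leaf and a NestedTensorImpl-backed Tensor otherwise.
std::vector<Tensor> unbind_children(const Tensor& nested) {
  return wrap_tensor_node(get_nested_tensor_structure(nested).unbind());
}

} // namespace at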
@@ -208,7 +216,7 @@ Tensor NestedTensor_contiguous(const Tensor& self, MemoryFormat memory_format) {
       "preserve memory format is unsupported by the contiguous operator");
   return wrap_tensor_node(
       map([](at::Tensor tensor) { return tensor.contiguous(); },
-          get_nested_tensor_impl(self)->get_structure()));
+          get_nested_tensor_structure(self)));
 }
 
 Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_) {
@@ -241,71 +249,38 @@ Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_) {
       result.push_back(TensorNode(std::move(ci)));
     }
   }
-  return at::detail::make_tensor<at::NestedTensorImpl>(TensorNode(std::move(result)));
+  return wrap_tensor_node(TensorNode(std::move(result)));
 }
 
 bool NestedTensor_is_pinned(const Tensor& self) {
   return get_nested_tensor_impl(self)->is_pinned();
 }
 
-std::vector<at::Tensor> NestedTensor_unbind(const at::Tensor &self, int64_t dim) {
+std::vector<at::Tensor> NestedTensor_unbind(
+    const at::Tensor& self,
+    int64_t dim) {
   auto _data = get_nested_tensor_impl(self);
   dim = at::maybe_wrap_dim(dim, _data->dim());
   auto node = _data->get_structure();
-  auto nested_dim = _data->nested_dim();
-  if (nested_dim == 1) {
-    if (dim == 0) {
-      std::vector<at::Tensor> result;
-      for (const auto& child : node.unbind()) {
-        result.push_back(child.payload());
-      }
-      return result;
-    } else {
-      int64_t dim_max_size = 0;
-      for (const auto& child : node.unbind()) {
-        int64_t dim_size = child.payload().size(dim - 1);
-        dim_max_size = dim_max_size > dim_size ? dim_max_size : dim_size;
-      }
-      std::vector<std::vector<TensorNode>> unbound;
-      unbound.resize(dim_max_size);
-      for (const auto& child : node.unbind()) {
-        std::vector<at::Tensor> unbound_tensors =
-            at::unbind(child.payload(), dim - 1);
-        for (size_t i = 0; i < unbound_tensors.size(); i++) {
-          unbound[i].push_back(TensorNode(std::move(unbound_tensors[i])));
-        }
-      }
-      std::vector<at::Tensor> result;
-      for (size_t i = 0; i < unbound.size(); i++) {
-        TensorNode tmp = TensorNode(std::move(unbound[i]));
-        result.push_back(at::detail::make_tensor<NestedTensorImpl>(std::move(tmp)));
-      }
-      return result;
-    }
-  }
-  std::vector<at::Tensor> unbound_thp;
-  for (auto child : node.unbind()) {
-    unbound_thp.push_back(at::detail::make_tensor<NestedTensorImpl>(std::move(child)));
-  }
   if (dim == 0) {
-    return unbound_thp;
+    return wrap_tensor_node(node.unbind());
   }
   std::vector<std::vector<TensorNode>> unbound;
-  for (size_t i = 0; i < unbound_thp.size(); i++) {
-    std::vector<at::Tensor> tmp = unbound_thp[i].unbind(dim - 1);
+  for (auto child : node.unbind()) {
+    std::vector<at::Tensor> tmp =
+        at::unbind(wrap_tensor_node(std::move(child)), dim - 1);
     for (size_t j = 0; j < tmp.size(); j++) {
-      if (unbound.size() >= j) {
+      if (j >= unbound.size()) {
        unbound.resize(j + 1);
      }
      unbound[j].push_back(TensorNode(std::move(tmp[j])));
    }
  }
-  std::vector<at::Tensor> result;
+  std::vector<TensorNode> result;
  for (size_t i = 0; i < unbound.size(); i++) {
-    result.push_back(at::detail::make_tensor<NestedTensorImpl>(
-        TensorNode(std::move(unbound[i]))));
+    result.push_back(TensorNode(std::move(unbound[i])));
  }
-  return result;
+  return wrap_tensor_node(result);
 }
 
 Tensor NestedTensor_select(const Tensor& self, int64_t dim, int64_t index) {
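The rewritten dim > 0 branch above unbinds each wrapped child along dim - 1 and regroups the slices by index before re-wrapping. The guard is also corrected so the outer vector only grows when a new slice index appears; the old unbound.size() >= j test resized the vector on almost every iteration and could even shrink it, discarding already collected groups. The regrouping step in isolation, over plain tensors, as an illustrative sketch with a hypothetical helper name:

#include <ATen/ATen.h>
#include <vector>

// Sketch: collect the j-th slice of every child into group j, growing the
// group list only when a slice index appears for the first time.
std::vector<std::vector<at::Tensor>> regroup_slices(
    const std::vector<at::Tensor>& children,
    int64_t dim) {
  std::vector<std::vector<at::Tensor>> grouped;
  for (const auto& child : children) {
    std::vector<at::Tensor> slices = at::unbind(child, dim - 1);
    for (size_t j = 0; j < slices.size(); j++) {
      if (j >= grouped.size()) {  // grow, never shrink
        grouped.resize(j + 1);
      }
      grouped[j].push_back(std::move(slices[j]));
    }
  }
  return grouped;
}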
@@ -314,17 +289,19 @@ Tensor NestedTensor_select(const Tensor& self, int64_t dim, int64_t index) {
   if (dim == 0) {
     TORCH_CHECK_INDEX(false, "select() only supports dim == 0 for now.");
   }
-  TensorNode tn = get_nested_tensor_impl(self)->get_structure().unbind()[index];
-  return at::detail::make_tensor<NestedTensorImpl>(std::move(tn));
+  auto children = get_nested_tensor_structure(self).unbind();
+  auto child = children[index];
+  return wrap_tensor_node(std::move(child));
 }
 
-Tensor NestedTensor_clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
-  auto self_impl = get_nested_tensor_impl(src);
-  return at::detail::make_tensor<NestedTensorImpl>(
-      map([&optional_memory_format](Tensor a) {
-        return at::clone(a, optional_memory_format);
-      },
-      self_impl->get_structure()));
+Tensor NestedTensor_clone(
+    const Tensor& src,
+    c10::optional<c10::MemoryFormat> optional_memory_format) {
+  return wrap_tensor_node(map(
+      [&optional_memory_format](Tensor a) {
+        return at::clone(a, optional_memory_format);
+      },
+      get_nested_tensor_structure(src)));
 }
 
 Tensor& NestedTensor_copy_(Tensor& self, const Tensor& src, bool non_blocking) {
@@ -404,4 +381,4 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1_PreAutograd, m) {
   m.impl_UNBOXED("unbind.int", NestedTensor_unbind);
   m.impl_UNBOXED("select.int", NestedTensor_select);
 }
-}
+} // namespace at

nestedtensor/csrc/nested_tensor_impl.h (+12 -10)

@@ -1,6 +1,6 @@
 #pragma once
-#include <nestedtensor/csrc/utils/nested_node.h>
 #include <ATen/ATen.h>
+#include <nestedtensor/csrc/utils/nested_node.h>
 
 namespace torch {
 namespace nested_tensor {
@@ -23,9 +23,12 @@ struct NestedTensorImpl;
 
 bool is_nested_tensor_impl(const at::Tensor tensor);
 at::NestedTensorImpl* get_nested_tensor_impl(const at::Tensor tensor);
-torch::nested_tensor::TensorNode get_nested_tensor_structure(const at::Tensor tensor);
+torch::nested_tensor::TensorNode get_nested_tensor_structure(
+    const at::Tensor tensor);
 
-at::Tensor wrap_tensor_node(TensorNode&& result);
+at::Tensor wrap_tensor_node(NestedTensorImpl);
+at::Tensor wrap_tensor_node(TensorNode&&);
+std::vector<at::Tensor> wrap_tensor_node(std::vector<TensorNode>);
 
 struct NestedTensorImpl : public c10::TensorImpl {
   explicit NestedTensorImpl(TensorNode structure);
@@ -39,8 +42,7 @@ struct NestedTensorImpl : public c10::TensorImpl {
     };
     return reduce<decltype(fn), int64_t, at::Tensor>(get_structure(), fn, 0);
   }
-  bool is_contiguous(
-      at::MemoryFormat memory_format) const override {
+  bool is_contiguous(at::MemoryFormat memory_format) const override {
     // NOTE: The Tensors themselves might not be contiguous even if there is a
     // buffer. For this to be contiguous not only the individuals Tensors have
     // to be but also the buffer.
@@ -74,8 +76,7 @@ struct NestedTensorImpl : public c10::TensorImpl {
       throw std::runtime_error("Grad is undefined");
    }
    return wrap_tensor_node(
-        map([](at::Tensor tensor) {
-          return tensor.grad(); }, get_structure()));
+        map([](at::Tensor tensor) { return tensor.grad(); }, get_structure()));
  }
  Tensor requires_grad_(bool requires_grad) {
    apply(
@@ -138,7 +139,6 @@ struct NestedTensorImpl : public c10::TensorImpl {
  std::vector<int64_t> _sizes;
 };
 
-
 inline bool is_tensor_shape(const at::Tensor tensor) {
   auto nt = get_nested_tensor_impl(tensor);
   for (const auto& size : nt->opt_sizes()) {
@@ -151,12 +151,14 @@ inline bool is_tensor_shape(const at::Tensor tensor) {
 
 Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_);
 
-inline std::ostream& operator<<(std::ostream& out, const NestedTensorImpl& batch_tensor) {
+inline std::ostream& operator<<(
+    std::ostream& out,
+    const NestedTensorImpl& batch_tensor) {
   auto node = batch_tensor.get_structure();
   out << "NESTED_TENSOR";
   apply([&out](at::Tensor tensor) { out << tensor << std::endl; }, node);
   out << std::endl;
   return out;
 }
 
-}
+} // namespace at

nestedtensor/nested/nested.py (+4)

@@ -11,6 +11,10 @@
 import itertools
 
 def _wrap_result(result):
+    if isinstance(result, list):
+        return list(_wrap_result(r) for r in result)
+    if isinstance(result, tuple):
+        return tuple(_wrap_result(r) for r in result)
     return (
         NestedTensor(result)
         if torch.is_tensor(result) and torch.ops.nestedtensor.is_nested_tensor_impl(result)

nestedtensor/version.py (+2 -2)

@@ -1,5 +1,5 @@
-__version__ = '0.0.1.dev202061718+94aadd5'
-git_version = '94aadd589c657f116508eb3547104b19ddbcb63c'
+__version__ = '0.0.1.dev202071515+2fb94d8'
+git_version = '2fb94d8d788650f4bf4340988b0a2c0a3684fbe2'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION
