This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Change as_nested_tensor to only forward NestedTensor type #153

Merged · 8 commits · May 20, 2020
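Summary: as the diffs below show, this PR separates the two Python constructors. nested_tensor(data) now always builds a new, contiguous NestedTensor (the copy happens in the C++ constructor), while as_nested_tensor(data) forwards its argument untouched when it is already a NestedTensor and only constructs one otherwise. The test updates all follow from these copy semantics.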
4 changes: 2 additions & 2 deletions nestedtensor/csrc/creation.cpp
@@ -182,7 +182,7 @@ NestedNode<c10::IValue> py_to_nested_tensor(const py::object& py_obj) {
  }
}

-THPNestedTensor as_nested_tensor(py::sequence list) {
+THPNestedTensor nested_tensor(py::sequence list) {
  NestedNode<c10::IValue> ivalue_structure = py_to_nested_tensor(list);
  auto fn = [](c10::IValue a, bool result) { return result && a.isTensor(); };
  bool all_same =
@@ -197,7 +197,7 @@ THPNestedTensor as_nested_tensor(py::sequence list) {
      _verify_variables(*first, structure, true);
    }
  }
-  return THPNestedTensor(NestedTensor(std::move(structure)));
+  return THPNestedTensor(NestedTensor(std::move(structure)).contiguous());
}

} // namespace nested_tensor
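Note: the C++ constructor, renamed to nested_tensor, now appends .contiguous() to the freshly built NestedTensor, so contiguity is guaranteed at construction time and the Python-level nested_tensor below no longer needs its own .contiguous() call.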
2 changes: 1 addition & 1 deletion nestedtensor/csrc/creation.h
@@ -6,7 +6,7 @@ namespace nested_tensor {

NestedNode<py::object> py_to_nested_node(py::object&& py_obj);

-THPNestedTensor as_nested_tensor(pybind11::sequence list);
+THPNestedTensor nested_tensor(pybind11::sequence list);

} // namespace nested_tensor
} // namespace torch
2 changes: 1 addition & 1 deletion nestedtensor/csrc/py_init.cpp
@@ -91,5 +91,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

  // NOTE: This is a private function until it is feature complete
  m.def("_jit_tensorwise", &torch::nested_tensor::jit_tensorwise);
-  m.def("as_nested_tensor", &torch::nested_tensor::as_nested_tensor);
+  m.def("nested_tensor", &torch::nested_tensor::nested_tensor);
}
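With this rename, _C.nested_tensor is the only C++ entry point; as_nested_tensor becomes a pure-Python wrapper in creation.py below.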
22 changes: 10 additions & 12 deletions nestedtensor/nested/creation.py
@@ -5,23 +5,12 @@
from . import utils
from nestedtensor import _C

-def as_nested_tensor(data, dtype=None, device=None):
-    # Simple wrapper around a nested list of Tensors.
-    # Shares memory with original objects.
-    # # TODO: Needs tests to check failure cases
-    ret_impl = _C.as_nested_tensor(data)
-    ret = nested.NestedTensor(ret_impl)
-    if dtype is not None:
-        ret = ret.to(dtype)
-    if device is not None:
-        ret = ret.to(device)
-    return ret
-
def nested_tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False):
    """
    Arguments match torch.tensor
    """
-    result = as_nested_tensor(data).contiguous()
+    result = nested.NestedTensor(_C.nested_tensor(data))

    if dtype is not None or device is not None:
        result = result.to(dtype=dtype, device=device)
@@ -30,3 +19,12 @@ def nested_tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False):
    if pin_memory:
        result = result.pin_memory()
    return result


+def as_nested_tensor(data, dtype=None, device=None):
+    # TODO: Needs tests to check failure cases
+    if not utils.is_nested_tensor(data):
+        data = nested_tensor(data, dtype, device)
+    if dtype is not None or device is not None:
+        return data.to(dtype=dtype, device=device)
+    return data

Review thread on the TODO line:
Contributor: Just do it(c)
Contributor Author (reply): Follow-up PR
Contributor Author (reply): Because there are potentially a lot of edge cases
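A minimal usage sketch of the semantics implemented above, mirroring the assertions in the updated tests (the tensor shapes and variable names are illustrative):

import torch
import nestedtensor

tensors = [torch.rand(2, 3), torch.rand(4, 3)]

# nested_tensor copies its inputs into fresh contiguous storage,
# so mutating a source tensor does not affect the NestedTensor.
nt = nestedtensor.nested_tensor(tensors)
tensors[0].mul_(2)

# as_nested_tensor forwards an existing NestedTensor as-is...
assert nestedtensor.as_nested_tensor(nt) is nt
# ...but a dtype/device conversion may return a new object.
assert nestedtensor.as_nested_tensor(nt, dtype=torch.int64) is not nt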
4 changes: 2 additions & 2 deletions nestedtensor/version.py
@@ -1,5 +1,5 @@
-__version__ = '0.0.1.dev202052020+b9fcbb0'
-git_version = 'b9fcbb0bf8f205934a493d8bd308c69bd09c9d2e'
+__version__ = '0.0.1.dev202052023+2de5c5f'
+git_version = '2de5c5fb95e5fdc06d1d38ba14bf47d2d31c8f7e'
from nestedtensor import _C
if hasattr(_C, 'CUDA_VERSION'):
    cuda = _C.CUDA_VERSION
10 changes: 5 additions & 5 deletions test/test_nested_tensor_autograd.py
@@ -25,8 +25,8 @@ def some_func(x):
        nt_sum_res = some_func(nt)
        nt_sum_res.backward()
        self.assertEqual(sum_res, nt_sum_res)
-        self.assertEqual(verification_tensor.grad.data, nt[0].grad.data)
-        self.assertEqual(verification_tensor.grad.data, tensor.grad.data)
+        self.assertIsNone(nt[0].grad)
+        self.assertIsNotNone(verification_tensor.grad)

        # nested_tensor constructor
        tensor2 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float, requires_grad=True)
@@ -55,9 +55,9 @@ def some_func(x):
        sum_res.backward()

        self.assertEqual(sum_res, nt_sum_res)
-        self.assertEqual(nt1[0].grad.data, tensor.grad[0].data)
-        self.assertEqual(nt1[1].grad.data, tensor.grad[1].data.masked_select(mask[1]))
-        self.assertEqual(nt1[2].grad.data, tensor.grad[2].data.masked_select(mask[2]))
+        self.assertIsNone(nt1[0].grad)
+        self.assertIsNone(nt1[1].grad)
+        self.assertIsNone(nt1[2].grad)

        self.assertIsNone(nt2[0].grad)
        self.assertIsNone(nt2[1].grad)
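The flipped assertions reflect the new copy semantics: the constituent tensors passed into the constructor are copied, so backward() no longer populates their .grad; only independently used tensors such as verification_tensor still receive gradients.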
57 changes: 27 additions & 30 deletions test/test_nested_tensor_class.py
@@ -53,22 +53,22 @@ def test_list_constructor(self):
        for i in range(num_tensors):
            tensors[i].mul_(i + 2)
        for i in range(num_tensors):
-            self.assertEqual(tensors[i], nested_tensor.unbind()[i])
-            self.assertEqual(tensors[i].storage().data_ptr(
+            self.assertNotEqual(tensors[i], nested_tensor.unbind()[i])
+            self.assertNotEqual(tensors[i].storage().data_ptr(
            ), nested_tensor.unbind()[i].storage().data_ptr())

-    def test_mutation(self):
+    def test_as_nested_tensor(self):
        tensors = []
        num_tensors = 16
        for i in range(num_tensors):
            tensors.append(utils.gen_float_tensor(i, (i + 1, 128, 128)))

-        # This should create references
+        # This should NOT create references
        nested_tensor = nestedtensor.as_nested_tensor(tensors)
        for i in range(num_tensors):
            tensors[i].mul_(i + 2)
        for i in range(num_tensors):
-            self.assertEqual(tensors[i], nested_tensor.unbind()[i])
+            self.assertNotEqual(tensors[i], nested_tensor.unbind()[i])

        # This should NOT create references
        nested_tensor = nestedtensor.nested_tensor(tensors)
@@ -77,6 +77,11 @@ def test_mutation(self):
        for i in range(num_tensors):
            self.assertNotEqual(tensors[i], nested_tensor.unbind()[i])

+        nested_tensor1 = nestedtensor.as_nested_tensor(nested_tensor)
+        self.assertTrue(nested_tensor1 is nested_tensor)
+        nested_tensor2 = nestedtensor.as_nested_tensor(nested_tensor, dtype=torch.int64)
+        self.assertTrue(nested_tensor2 is not nested_tensor)

    def test_constructor(self):
        for constructor in _iter_constructors():
            self.assertRaises(
@@ -229,23 +234,15 @@ def test_nested_size(self):
        self.assertRaises(IndexError, lambda: a.nested_size(2))

    def test_nested_stride(self):
-        tensors = [torch.rand(1, 2, 4)[:, :, 0], torch.rand(
-            2, 3, 4)[:, 1, :], torch.rand(3, 4, 5)[1, :, :]]
-        a = nestedtensor.as_nested_tensor(tensors)
-        na = tuple(tuple(t.stride()) for t in tensors)
-        ans = a.nested_stride()
-        result = tuple(ans[i] for i in range(len(ans)))
-        for r, s in zip(result, na):
-            self.assertEqual(r, s)
-
-        tensors = [torch.rand(1, 2, 4)[:, :, 0], torch.rand(
-            2, 3, 4)[:, 1, :], torch.rand(3, 4, 5)[1, :, :]]
-        a = nestedtensor.nested_tensor(tensors)
-        na = list(list(t.contiguous().stride()) for t in tensors)
-        ans = a.nested_stride()
-        result = tuple(ans[i] for i in range(len(ans)))
-        for r, s in zip(result, na):
-            self.assertEqual(r, s)
+        for constructor in _iter_constructors():
+            tensors = [torch.rand(1, 2, 4)[:, :, 0], torch.rand(
+                2, 3, 4)[:, 1, :], torch.rand(3, 4, 5)[1, :, :]]
+            a = constructor(tensors)
+            na = list(list(t.contiguous().stride()) for t in tensors)
+            ans = a.nested_stride()
+            result = tuple(ans[i] for i in range(len(ans)))
+            for r, s in zip(result, na):
+                self.assertEqual(r, s)

    def test_len(self):
        for constructor in _iter_constructors():
@@ -310,19 +307,19 @@ def test_unbind(self):
        # TODO: contiguous nestedtensors should return tuples of contiguous nestedtensors on dimension 0

        def _test(a, b, c, d, e):
-            nt = nestedtensor.as_nested_tensor([a, b])
+            nt = nestedtensor.nested_tensor([a, b])
            a1, b1 = nt.unbind()
-            self.assertTrue(a is a1)
-            self.assertTrue(b is b1)
+            self.assertTrue(a is not a1)
+            self.assertTrue(b is not b1)

-            nt1 = nestedtensor.as_nested_tensor([[c, d], [e]])
+            nt1 = nestedtensor.nested_tensor([[c, d], [e]])
            nt11, nt12 = nt1.unbind()
            c1, d1 = nt11.unbind()
            e1 = nt12.unbind()[0]

-            self.assertTrue(c is c1)
-            self.assertTrue(d is d1)
-            self.assertTrue(e is e1)
+            self.assertTrue(c is not c1)
+            self.assertTrue(d is not d1)
+            self.assertTrue(e is not e1)

            nt = nestedtensor.nested_tensor([a, b])
            a1, b1 = nt.unbind()
@@ -610,7 +607,7 @@ def test_contiguous(self):
            torch.tensor([3, 4]),
            torch.tensor([5, 6]),
            torch.tensor([7, 8])])
-        self.assertTrue(not a.is_contiguous())
+        self.assertTrue(a.is_contiguous())


if __name__ == "__main__":
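All of the flipped assertions in this file follow from the same change: a constructed NestedTensor owns contiguous copies of its inputs, so source mutations are no longer visible (test_list_constructor, test_as_nested_tensor), unbind() returns distinct objects rather than the original tensors (test_unbind), and is_contiguous() holds immediately after construction (test_contiguous).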
2 changes: 1 addition & 1 deletion test/test_nested_tensor_functional.py
@@ -184,7 +184,7 @@ def test_functional_relu_(self):
        nt1 = nestedtensor.as_nested_tensor([t_clone])
        torch.nn.functional.relu_(nt1)
        self.assertEqual(nt1, expected_nt)
-        self.assertEqual(t_clone, expected_t)
+        self.assertNotEqual(t_clone, expected_t)

    def test_nn_relu(self):
        inputs = [
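Likewise, since as_nested_tensor copies a plain tensor list on construction, the in-place relu_ now leaves t_clone untouched.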