Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix as_tensor not keeping the parent alive in Python #60

Merged
merged 4 commits into from
Jul 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions dali/pipeline/data/tensor_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ typedef vector<Index> Dims;
template <typename Backend>
class DLL_PUBLIC TensorList : public Buffer<Backend> {
public:
DLL_PUBLIC TensorList() : layout_(DALI_NHWC) {}
DLL_PUBLIC ~TensorList() = default;
// Default-constructs an empty TensorList with NHWC layout and no cached
// tensor view; AsTensor() creates the view lazily on first use.
DLL_PUBLIC TensorList() : layout_(DALI_NHWC),
tensor_view_(nullptr) {}

// Frees the lazily created tensor view, if any. `delete nullptr` is a
// no-op, so this is safe when AsTensor() was never called.
// NOTE(review): this class now owns a raw pointer, but no deleted/defined
// copy or move operations are visible in this hunk -- confirm copying is
// disabled elsewhere, otherwise two TensorList copies would both delete
// the same tensor_view_ (double free).
DLL_PUBLIC ~TensorList() {
delete tensor_view_;
}

/**
* @brief Resizes this TensorList to match the shape of the input.
Expand Down Expand Up @@ -107,6 +111,11 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
// Resize the underlying allocation and save the new shape
ResizeHelper(new_size);
shape_ = new_shape;

// Tensor view of this TensorList is no longer valid
if (tensor_view_) {
tensor_view_->ShareData(this);
}
}

/**
Expand All @@ -116,8 +125,8 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
*
* When this function is called, the calling object shares the
* underlying allocation of the input TensorList. Its size, type
* and shape are set to match the calling TensorList. While this
* list shares data with another list, 'shares_data()' will
* and shape are set to match the calling TensorList. While this
* list shares data with another list, 'shares_data()' will
* return 'true'.
*/
DLL_PUBLIC inline void ShareData(TensorList<Backend> *other) {
Expand All @@ -134,6 +143,11 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
num_bytes_ = other->num_bytes_;
device_ = other->device_;

// Tensor view of this TensorList is no longer valid
if (tensor_view_) {
tensor_view_->ShareData(this);
}

// If the other tensor has a non-zero size allocation, mark that
// we are now sharing an allocation with another buffer
shares_data_ = num_bytes_ > 0 ? true : false;
Expand All @@ -143,10 +157,10 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
* @brief Wraps the raw allocation. The input pointer must not be nullptr.
* if the size of the allocation is zero, the TensorList is reset to
* a default state and is NOT marked as sharing data.
*
* After wrapping the allocation, the TensorLists size is set to 0,
* and its type is reset to NoType. Future calls to Resize or setting
* of the Tensor type will evaluate whether or not the current
*
* After wrapping the allocation, the TensorLists size is set to 0,
* and its type is reset to NoType. Future calls to Resize or setting
* of the Tensor type will evaluate whether or not the current
* allocation is large enough to be used and proceed appropriately.
*
* The TensorList object assumes no ownership of the input allocation,
Expand All @@ -165,6 +179,11 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
offsets_.clear();
size_ = 0;

// Tensor view of this TensorList is no longer valid
if (tensor_view_) {
tensor_view_->ShareData(this);
}

// If the input pointer stores a non-zero size allocation, mark
// that we are sharing our underlying data
shares_data_ = num_bytes_ > 0 ? true : false;
Expand Down Expand Up @@ -266,6 +285,22 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
return true;
}

/**
* @brief Returns a Tensor which shares the data
* with this TensorList. The tensor obtained
* through this function stays valid for the lifetime
* of the parent TensorList.
*
* The view is created lazily on the first call and cached in
* tensor_view_, so repeated calls return the same pointer. The returned
* Tensor is owned by this TensorList (deleted in ~TensorList()); callers
* must not delete it. Sibling hunks in this file re-call
* ShareData(this) on tensor_view_ whenever the list is resized or
* re-shared, keeping the cached view in sync.
*
* NOTE(review): presumably only meaningful when IsDenseTensor() returns
* true (per the tensor_view_ member comment) -- no check is visible
* here; confirm Tensor::ShareData enforces it.
*/
Tensor<Backend> * AsTensor() {
if (tensor_view_ == nullptr) {
tensor_view_ = new Tensor<Backend>();
tensor_view_->ShareData(this);
}

return tensor_view_;
}


// So we can access the members of other TensorLists
// with different template types
template <typename InBackend>
Expand All @@ -289,6 +324,12 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
vector<Index> offsets_;
DALITensorLayout layout_;

// To avoid leaking memory (and to make repeated calls slightly faster)
// when sharing data with a Tensor, we store a pointer to the Tensor
// that shares the data with this TensorList (valid only
// if IsDenseTensor returns true).
Tensor<Backend> * tensor_view_;

USE_BUFFER_MEMBERS();
};

Expand Down
18 changes: 4 additions & 14 deletions dali/python/backend_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,13 @@ void ExposeTensorList(py::module &m) { // NOLINT
Parameters
----------
)code")
.def("as_tensor",
[](TensorList<CPUBackend> &t) -> Tensor<CPUBackend>* {
Tensor<CPUBackend> * ret = new Tensor<CPUBackend>();
ret->ShareData(&t);
return ret;
},
.def("as_tensor", &TensorList<CPUBackend>::AsTensor,
R"code(
Returns a tensor that is a view of this `TensorList`.

This function can only be called if `is_dense_tensor` returns `True`.
)code",
py::return_value_policy::take_ownership);
py::return_value_policy::reference_internal);

py::class_<TensorList<GPUBackend>>(m, "TensorListGPU", py::buffer_protocol())
.def("__init__", [](TensorList<GPUBackend> &t) {
Expand Down Expand Up @@ -357,18 +352,13 @@ void ExposeTensorList(py::module &m) { // NOLINT
Parameters
----------
)code")
.def("as_tensor",
[](TensorList<GPUBackend> &t) -> Tensor<GPUBackend>* {
Tensor<GPUBackend> * ret = new Tensor<GPUBackend>();
ret->ShareData(&t);
return ret;
},
.def("as_tensor", &TensorList<GPUBackend>::AsTensor,
R"code(
Returns a tensor that is a view of this `TensorList`.

This function can only be called if `is_dense_tensor` returns `True`.
)code",
py::return_value_policy::take_ownership);
py::return_value_policy::reference_internal);
}

static vector<string> GetRegisteredCPUOps() {
Expand Down