diff --git a/dali/pipeline/data/tensor_list.h b/dali/pipeline/data/tensor_list.h
index c916ffb481f..7e3a8a92cd4 100644
--- a/dali/pipeline/data/tensor_list.h
+++ b/dali/pipeline/data/tensor_list.h
@@ -40,8 +40,12 @@ typedef vector<Index> Dims;
 template <typename Backend>
 class DLL_PUBLIC TensorList : public Buffer<Backend> {
  public:
-  DLL_PUBLIC TensorList() : layout_(DALI_NHWC) {}
-  DLL_PUBLIC ~TensorList() = default;
+  DLL_PUBLIC TensorList() : layout_(DALI_NHWC),
+                            tensor_view_(nullptr) {}
+
+  DLL_PUBLIC ~TensorList() {
+    delete tensor_view_;
+  }
 
   /**
    * @brief Resizes this TensorList to match the shape of the input.
@@ -107,6 +111,11 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
     // Resize the underlying allocation and save the new shape
     ResizeHelper(new_size);
     shape_ = new_shape;
+
+    // Tensor view of this TensorList is no longer valid
+    if (tensor_view_) {
+      tensor_view_->ShareData(this);
+    }
   }
 
   /**
@@ -116,8 +125,8 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
    *
    * When this function is called, the calling object shares the
    * underlying allocation of the input TensorList. Its size, type
-   * and shape are set to match the calling TensorList. While this 
-   * list shares data with another list, 'shares_data()' will 
+   * and shape are set to match the calling TensorList. While this
+   * list shares data with another list, 'shares_data()' will
    * return 'true'.
    */
   DLL_PUBLIC inline void ShareData(TensorList<Backend> *other) {
@@ -134,6 +143,11 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
     num_bytes_ = other->num_bytes_;
     device_ = other->device_;
 
+    // Tensor view of this TensorList is no longer valid
+    if (tensor_view_) {
+      tensor_view_->ShareData(this);
+    }
+
     // If the other tensor has a non-zero size allocation, mark that
     // we are now sharing an allocation with another buffer
     shares_data_ = num_bytes_ > 0 ? true : false;
@@ -143,10 +157,10 @@
    * @brief Wraps the raw allocation. The input pointer must not be nullptr.
    * if the size of the allocation is zero, the TensorList is reset to
    * a default state and is NOT marked as sharing data.
-   * 
-   * After wrapping the allocation, the TensorLists size is set to 0, 
-   * and its type is reset to NoType. Future calls to Resize or setting 
-   * of the Tensor type will evaluate whether or not the current 
+   *
+   * After wrapping the allocation, the TensorLists size is set to 0,
+   * and its type is reset to NoType. Future calls to Resize or setting
+   * of the Tensor type will evaluate whether or not the current
    * allocation is large enough to be used and proceed appropriately.
    *
    * The TensorList object assumes no ownership of the input allocation,
@@ -165,6 +179,11 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
     offsets_.clear();
     size_ = 0;
 
+    // Tensor view of this TensorList is no longer valid
+    if (tensor_view_) {
+      tensor_view_->ShareData(this);
+    }
+
     // If the input pointer stores a non-zero size allocation, mark
     // that we are sharing our underlying data
     shares_data_ = num_bytes_ > 0 ? true : false;
@@ -266,6 +285,22 @@
     return true;
   }
 
+  /**
+   * @brief Returns a Tensor which shares the data
+   * with this TensorList. The tensor obtained
+   * through this function stays valid for the lifetime
+   * of the parent TensorList.
+   */
+  Tensor<Backend> * AsTensor() {
+    if (tensor_view_ == nullptr) {
+      tensor_view_ = new Tensor<Backend>();
+      tensor_view_->ShareData(this);
+    }
+
+    return tensor_view_;
+  }
+
+
   // So we can access the members of other TensorListes
   // with different template types
   template <typename InBackend>
@@ -289,6 +324,12 @@
   vector<Index> offsets_;
   DALITensorLayout layout_;
 
+  // In order to not leak memory (and make it slightly faster)
+  // when sharing data with a Tensor, we will store a pointer to
+  // Tensor that shares the data with this TensorList (valid only
+  // if IsDenseTensor returns true)
+  Tensor<Backend> * tensor_view_;
+
   USE_BUFFER_MEMBERS();
 };
diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc
index c03d652bb5c..0318e2953a9 100644
--- a/dali/python/backend_impl.cc
+++ b/dali/python/backend_impl.cc
@@ -295,18 +295,13 @@ void ExposeTensorList(py::module &m) { // NOLINT
       Parameters
       ----------
       )code")
-    .def("as_tensor",
-        [](TensorList<CPUBackend> &t) -> Tensor<CPUBackend>* {
-          Tensor<CPUBackend> * ret = new Tensor<CPUBackend>();
-          ret->ShareData(&t);
-          return ret;
-        },
+    .def("as_tensor", &TensorList<CPUBackend>::AsTensor,
       R"code(
       Returns a tensor that is a view of this `TensorList`.
 
       This function can only be called if `is_dense_tensor` returns `True`.
       )code",
-      py::return_value_policy::take_ownership);
+      py::return_value_policy::reference_internal);
 
   py::class_<TensorList<GPUBackend>>(m, "TensorListGPU", py::buffer_protocol())
     .def("__init__", [](TensorList<GPUBackend> &t) {
@@ -357,18 +352,13 @@ void ExposeTensorList(py::module &m) { // NOLINT
       Parameters
      ----------
       )code")
-    .def("as_tensor",
-        [](TensorList<GPUBackend> &t) -> Tensor<GPUBackend>* {
-          Tensor<GPUBackend> * ret = new Tensor<GPUBackend>();
-          ret->ShareData(&t);
-          return ret;
-        },
+    .def("as_tensor", &TensorList<GPUBackend>::AsTensor,
       R"code(
       Returns a tensor that is a view of this `TensorList`.
 
       This function can only be called if `is_dense_tensor` returns `True`.
       )code",
-      py::return_value_policy::take_ownership);
+      py::return_value_policy::reference_internal);
 }
 
 static vector<string> GetRegisteredCPUOps() {
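For reviewers, a minimal C++ sketch of the caching behaviour this change introduces; the `Resize(const vector<Dims>&)` and `mutable_data<T>()` calls are assumed from the existing `TensorList`/`Buffer` interface, and the shapes and helper-function name are illustrative only. Repeated `AsTensor()` calls now return the same cached view, and `Resize`/`ShareData` re-share the current allocation with that view instead of leaving it dangling.

```cpp
#include <cassert>

#include "dali/pipeline/data/tensor_list.h"

using namespace dali;  // NOLINT

// Hypothetical helper, only to illustrate the intended semantics.
void AsTensorCachingSketch() {
  TensorList<CPUBackend> tl;
  tl.Resize({{2, 3}, {2, 3}});  // two samples with identical shape -> dense
  tl.mutable_data<float>();     // force the allocation

  Tensor<CPUBackend> *view = tl.AsTensor();   // first call creates and caches the view
  Tensor<CPUBackend> *again = tl.AsTensor();  // later calls return the same object
  assert(view == again);

  // Resize may reallocate; the TensorList re-shares its buffer with the
  // cached view, so `view` stays valid for the lifetime of `tl`.
  tl.Resize({{4, 3}, {4, 3}});
  assert(view == tl.AsTensor());
}  // ~TensorList deletes the cached view together with the list.
```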
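On the Python side, the policy switch is what makes the cached view safe to expose: `take_ownership` would hand Python a pointer that the `TensorList` still owns (and deletes in its destructor), while `reference_internal` leaves ownership with the C++ object and keeps the parent alive for as long as the returned tensor is referenced. A reduced pybind11 sketch of that ownership pattern, using hypothetical `Owner`/`View` types rather than DALI's classes:

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct View { int value = 0; };

struct Owner {
  // Lazily created view, owned by Owner (mirrors tensor_view_).
  View *view() {
    if (!view_) view_ = new View();
    return view_;
  }
  ~Owner() { delete view_; }
  View *view_ = nullptr;
};

PYBIND11_MODULE(example, m) {
  py::class_<View>(m, "View");
  py::class_<Owner>(m, "Owner")
      .def(py::init<>())
      // reference_internal: Python does not take ownership of the View and
      // keeps the Owner alive while the returned View is referenced.
      // take_ownership here would let Python delete an object the Owner
      // still points to and later deletes again -- a double free.
      .def("view", &Owner::view, py::return_value_policy::reference_internal);
}
```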