Fix as_tensor not keeping the parent alive in Python #60
Merged
Changes (from 1 of 4 commits):

```diff
@@ -40,8 +40,12 @@ typedef vector<Index> Dims;
 template <typename Backend>
 class DLL_PUBLIC TensorList : public Buffer<Backend> {
  public:
-  DLL_PUBLIC TensorList() : layout_(DALI_NHWC) {}
-  DLL_PUBLIC ~TensorList() = default;
+  DLL_PUBLIC TensorList() : layout_(DALI_NHWC),
+                            tensor_view_(nullptr) {}
+
+  DLL_PUBLIC ~TensorList() {
+    delete tensor_view_;
+  }

   /**
    * @brief Resizes this TensorList to match the shape of the input.
@@ -104,6 +108,10 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
     }
     DALI_ENFORCE(new_size >= 0, "Invalid negative buffer size.");

+    // Tensor view of this TensorList is no longer valid
+    delete tensor_view_;
+    tensor_view_ = nullptr;
+
     // Resize the underlying allocation and save the new shape
     ResizeHelper(new_size);
     shape_ = new_shape;
@@ -116,15 +124,19 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
    *
    * When this function is called, the calling object shares the
    * underlying allocation of the input TensorList. Its size, type
-   * and shape are set to match the calling TensorList. While this
-   * list shares data with another list, 'shares_data()' will
+   * and shape are set to match the calling TensorList. While this
+   * list shares data with another list, 'shares_data()' will
    * return 'true'.
    */
   DLL_PUBLIC inline void ShareData(TensorList<Backend> *other) {
     DALI_ENFORCE(other != nullptr, "Input TensorList is nullptr");
     DALI_ENFORCE(IsValidType(other->type_), "To share data, "
         "the input TensorList must have a valid data type");

+    // Tensor view of this TensorList is no longer valid
+    delete tensor_view_;
+    tensor_view_ = nullptr;
+
     // Save the calling TensorLists meta-data
     data_ = other->data_;
     shape_ = other->shape_;
@@ -143,10 +155,10 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
    * @brief Wraps the raw allocation. The input pointer must not be nullptr.
    * if the size of the allocation is zero, the TensorList is reset to
    * a default state and is NOT marked as sharing data.
-   *
-   * After wrapping the allocation, the TensorLists size is set to 0,
-   * and its type is reset to NoType. Future calls to Resize or setting
-   * of the Tensor type will evaluate whether or not the current
+   *
+   * After wrapping the allocation, the TensorLists size is set to 0,
+   * and its type is reset to NoType. Future calls to Resize or setting
+   * of the Tensor type will evaluate whether or not the current
    * allocation is large enough to be used and proceed appropriately.
    *
    * The TensorList object assumes no ownership of the input allocation,
@@ -157,6 +169,10 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
   DLL_PUBLIC inline void ShareData(void *ptr, size_t bytes) {
     DALI_ENFORCE(ptr != nullptr, "Input pointer must not be nullptr.");

+    // Tensor view of this TensorList is no longer valid
+    delete tensor_view_;
+    tensor_view_ = nullptr;
+
     // Save our new pointer and bytes. Reset our type, shape, and size
     data_.reset(ptr, [](void *) {});
     num_bytes_ = bytes;
@@ -266,6 +282,16 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
     return true;
   }

+  Tensor<Backend> * AsTensor() {
+    if (tensor_view_ == nullptr) {
+      tensor_view_ = new Tensor<Backend>();
+      tensor_view_->ShareData(this);
+    }
+
+    return tensor_view_;
+  }
+
   // So we can access the members of other TensorListes
   // with different template types
   template <typename InBackend>
@@ -289,6 +315,12 @@ class DLL_PUBLIC TensorList : public Buffer<Backend> {
   vector<Index> offsets_;
   DALITensorLayout layout_;

+  // In order to not leak memory (and make it slightly faster)
+  // when sharing data with a Tensor, we will store a pointer to
+  // Tensor that shares the data with this TensorList (valid only
+  // if IsDenseTensor returns true)
+  Tensor<Backend> * tensor_view_;
+
   USE_BUFFER_MEMBERS();
 };
```

(The docstring lines marked as changed in the two middle hunks differ only in trailing whitespace.)

Review comment on AsTensor(): Please add docs to this method.
Reply: Sure.
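A possible shape for that documentation, sketched here only to show what was asked for; the wording that was actually merged may differ:

```cpp
/**
 * @brief Returns a Tensor viewing the allocation backing this TensorList.
 *
 * The view is created lazily on the first call and cached in tensor_view_;
 * the TensorList owns it and invalidates it whenever the underlying
 * allocation changes (Resize, ShareData). The view is only meaningful
 * when the list forms a dense tensor, i.e. IsDenseTensor() returns true.
 */
Tensor<Backend> *AsTensor();
```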
Review comment: I don't know if Python will handle this. What if we call AsTensor and return the pointer, and then call ShareData? The pointer we just returned to Python is no longer valid, but Python can still reference it. Or am I wrong?
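A minimal, self-contained sketch of the scenario being described; the classes below are stand-ins for TensorList/Tensor that only model the cached-view ownership, not DALI code:

```cpp
// Toy stand-in: the "list" owns a lazily created view object and deletes
// it whenever its storage changes, mirroring tensor_view_ in the diff.
struct View { int *data; };

struct List {
  int storage[4] = {1, 2, 3, 4};
  View *view = nullptr;

  View *AsTensor() {             // analogous to TensorList::AsTensor()
    if (view == nullptr) view = new View{storage};
    return view;
  }

  void ShareData(List *other) {  // analogous to TensorList::ShareData()
    delete view;                 // the cached view is invalidated here
    view = nullptr;
    (void)other;                 // (storage would now alias other's data)
  }

  ~List() { delete view; }
};

int main() {
  List a, b;
  View *v = a.AsTensor();  // pointer handed out, e.g. to a Python wrapper
  a.ShareData(&b);         // deletes the cached view: 'v' now dangles
  // Any use of 'v' from here on is a use-after-free, even though the
  // Python object that received it can still be referenced.
  (void)v;
  return 0;
}
```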
Reply: This is a much longer discussion. To handle it, we would need to store information about all of the objects sharing data with a particular Tensor or TensorList and notify them when the data is no longer valid. That could have performance implications.
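Purely to illustrate the bookkeeping that comment describes (this is not something proposed or implemented in the PR), such a scheme would look roughly like:

```cpp
#include <mutex>
#include <unordered_set>

// Hypothetical sharer registry: every object viewing an allocation
// registers itself, and the owner walks the set under a lock whenever the
// allocation is resized or re-pointed. The locking on every hand-out and
// every resize is where the performance cost would come from.
struct Sharer {
  virtual void Invalidate() = 0;   // called when the shared data goes away
  virtual ~Sharer() = default;
};

struct SharedAllocation {
  std::mutex mtx;
  std::unordered_set<Sharer *> sharers;

  void Register(Sharer *s)   { std::lock_guard<std::mutex> g(mtx); sharers.insert(s); }
  void Unregister(Sharer *s) { std::lock_guard<std::mutex> g(mtx); sharers.erase(s); }

  void OnStorageChanged() {        // would be called from Resize / ShareData
    std::lock_guard<std::mutex> g(mtx);
    for (Sharer *s : sharers) s->Invalidate();
  }
};
```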
Reply: Oh, I see what you mean; I was referring to the other problem with sharing data. Hmmm, instead of deleting the pointer I guess I could keep it but invoke ShareData again after the internal pointer to data_ changes.
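A rough sketch of that alternative, using the members from the diff (tensor_view_, ShareData); the helper name is hypothetical and this is not what the PR ultimately does:

```cpp
// Hypothetical member of TensorList<Backend>, shown only to illustrate the
// suggestion above. It would be called from Resize / ShareData after data_,
// shape_ and the type have been updated: instead of deleting the cached
// view, re-point it at the new allocation so a previously returned pointer
// stays usable.
template <typename Backend>
void TensorList<Backend>::RefreshTensorView() {
  if (tensor_view_ != nullptr) {
    tensor_view_->ShareData(this);   // re-share with the new allocation
  }
}
```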
Reply: That is one solution to our problem. Regarding keeping references alive on the core side, there are a few possible approaches; the first and second would require a lot of synchronization, which could hurt performance. Another option is to leave it as a design assumption, like iterator invalidation when the data being iterated over changes.
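On the Python side, the "keeping the parent alive" part of the title is usually expressed with pybind11 (which DALI's Python layer is built on) by tying the returned view's lifetime to the object that produced it. A minimal sketch, with the binding and class names assumed for illustration rather than copied from DALI:

```cpp
#include <pybind11/pybind11.h>
// (plus the DALI headers declaring TensorList<CPUBackend> and Tensor<CPUBackend>)

namespace py = pybind11;

// return_value_policy::reference_internal tells pybind11 that the returned
// Tensor* is owned by the TensorList and keeps the parent TensorList alive
// for as long as the returned Python object exists.
void BindTensorList(py::module &m) {
  py::class_<TensorList<CPUBackend>>(m, "TensorListCPU")
      .def("as_tensor", &TensorList<CPUBackend>::AsTensor,
           py::return_value_policy::reference_internal);
}
```

An explicit py::keep_alive<0, 1>() call policy would establish the same parent-child tie if a different return-value policy were preferred.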