Extend and relax TensorList sample APIs #4358

Merged: 3 commits, Oct 18, 2022
51 changes: 44 additions & 7 deletions dali/pipeline/data/tensor_list.cc
@@ -221,7 +221,7 @@ TensorList<Backend> &TensorList<Backend>::operator=(TensorList<Backend> &&other)
template <typename Backend>
void TensorList<Backend>::VerifySampleShareCompatibility(DALIDataType type, int sample_dim,
TensorLayout layout, bool pinned,
AccessOrder order, int device_id,
int device_id,
const std::string &error_suffix) {
// Checks in the order of class members
DALI_ENFORCE(this->type() == type,
@@ -241,9 +241,6 @@ void TensorList<Backend>::VerifySampleShareCompatibility(DALIDataType type, int
make_string("Sample must have the same pinned status as target batch, current: ",
this->is_pinned(), ", new: ", pinned, error_suffix));

DALI_ENFORCE(this->order() == order,
make_string("Sample must have the same order as the target batch", error_suffix));

DALI_ENFORCE(this->device_id() == device_id,
make_string("Sample must have the same device id as target batch, current: ",
this->device_id(), ", new: ", device_id, error_suffix));
@@ -261,7 +258,7 @@ void TensorList<Backend>::SetSample(int sample_idx, const TensorList<Backend> &s
if (&src.tensors_[src_sample_idx] == &tensors_[sample_idx])
return;
VerifySampleShareCompatibility(src.type(), src.shape().sample_dim(), src.GetLayout(),
src.is_pinned(), src.order(), src.device_id(),
src.is_pinned(), src.device_id(),
make_string(" for source sample idx: ", src_sample_idx,
" and target sample idx: ", sample_idx, "."));

@@ -270,6 +267,11 @@ void TensorList<Backend>::SetSample(int sample_idx, const TensorList<Backend> &s
// Setting a new share overwrites the previous one - so we can safely assume that even if
// we had a sample sharing into TL, it will be overwritten
tensors_[sample_idx].ShareData(src.tensors_[src_sample_idx]);
// ShareData copied the source order over, so we have to restore this batch's order.
// The sample will be accessed in this buffer's order, so we also need to wait for all
// pending work in the "incoming" src order.
tensors_[sample_idx].set_order(order(), false);
order().wait(src.order());

if (src.GetLayout().empty() && !GetLayout().empty()) {
tensors_[sample_idx].SetLayout(GetLayout());
@@ -284,14 +286,19 @@ void TensorList<Backend>::SetSample(int sample_idx, const Tensor<Backend> &owner
// Setting any individual sample converts the batch to non-contiguous mode
MakeNoncontiguous();
VerifySampleShareCompatibility(owner.type(), owner.shape().sample_dim(), owner.GetLayout(),
owner.is_pinned(), owner.order(), owner.device_id(),
owner.is_pinned(), owner.device_id(),
make_string(" for sample idx: ", sample_idx, "."));

shape_.set_tensor_shape(sample_idx, owner.shape());

// Setting a new share overwrites the previous one - so we can safely assume that even if
// we had a sample sharing into TL, it will be overwritten
tensors_[sample_idx].ShareData(owner);
// ShareData copied the owner's order over, so we have to restore this batch's order.
// The sample will be accessed in this buffer's order, so we also need to wait for all
// pending work in the "incoming" owner order.
tensors_[sample_idx].set_order(order(), false);
order().wait(owner.order());

if (owner.GetLayout().empty() && !GetLayout().empty()) {
tensors_[sample_idx].SetLayout(GetLayout());
@@ -307,7 +314,7 @@ void TensorList<Backend>::SetSample(int sample_idx, const shared_ptr<void> &ptr,
assert(sample_idx >= 0 && sample_idx < curr_num_tensors_);
// Setting any individual sample converts the batch to non-contiguous mode
MakeNoncontiguous();
VerifySampleShareCompatibility(type, shape.sample_dim(), layout, pinned, order, device_id,
VerifySampleShareCompatibility(type, shape.sample_dim(), layout, pinned, device_id,
make_string(" for sample idx: ", sample_idx, "."));

DALI_ENFORCE(!IsContiguous());
Expand All @@ -316,6 +323,11 @@ void TensorList<Backend>::SetSample(int sample_idx, const shared_ptr<void> &ptr,
// Setting a new share overwrites the previous one - so we can safely assume that even if
// we had a sample sharing into TL, it will be overwritten
tensors_[sample_idx].ShareData(ptr, bytes, pinned, shape, type, device_id, order);
// ShareData copied the requested order over, so we have to restore this batch's order.
// The sample will be accessed in this buffer's order, so we also need to wait for all
// pending work in the "incoming" order argument.
tensors_[sample_idx].set_order(this->order(), false);
this->order().wait(order);

if (layout.empty() && !GetLayout().empty()) {
tensors_[sample_idx].SetLayout(GetLayout());
@@ -573,6 +585,23 @@ void TensorList<Backend>::Resize(const TensorListShape<> &new_shape, DALIDataTyp
}


template <typename Backend>
void TensorList<Backend>::ResizeSample(int sample_idx, const TensorShape<> &new_shape) {
DALI_ENFORCE(IsValidType(type()),
"Sample in TensorList cannot be resized with invalid type. Set the type first for "
"the whole TensorList using set_type or Resize.");
DALI_ENFORCE(sample_dim() == new_shape.sample_dim(),
"Sample in TensorList cannot be resized with non-compatible batch dimension. Use "
"set_sample_dim or Resize to set correct sample dimension for the whole batch.");
// Bounds check
assert(sample_idx >= 0 && sample_idx < curr_num_tensors_);
// Resizing any individual sample converts the batch to non-contiguous mode
MakeNoncontiguous();
Contributor commented:

Suggested change:

MakeNoncontiguous();

to:

if (volume(new_shape) != volume(shape_.tensor_shape_span(sample_idx)))
  MakeNoncontiguous();

?

Contributor (author) replied:

As discussed on Slack, I think that if we want a variant that doesn't break contiguity, it should be a separate call like ReshapeSample, or ResizeSample with some parameter that would always enforce a valid volume. That way the behaviour is consistent and we always know the postconditions of the call.
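
A minimal sketch of the volume-preserving variant discussed above; ReshapeSample is hypothetical and not part of this PR:

// Hypothetical sketch, not part of this PR: reshape a sample without breaking
// batch contiguity by requiring that the element count stays the same.
template <typename Backend>
void TensorList<Backend>::ReshapeSample(int sample_idx, const TensorShape<> &new_shape) {
  DALI_ENFORCE(volume(new_shape) == volume(shape_.tensor_shape_span(sample_idx)),
               "ReshapeSample requires a shape with the same volume as the current one.");
  shape_.set_tensor_shape(sample_idx, new_shape);
  tensors_[sample_idx].Resize(new_shape);  // same volume, so no reallocation
}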

shape_.set_tensor_shape(sample_idx, new_shape);
tensors_[sample_idx].Resize(new_shape);
}


template <typename Backend>
void TensorList<Backend>::SetSize(int new_size) {
DALI_ENFORCE(new_size >= 0, make_string("Incorrect size: ", new_size));
@@ -821,6 +850,14 @@ void TensorList<Backend>::Copy(const TensorList<SrcBackend> &src, AccessOrder or
bool use_copy_kernel) {
auto copy_order = copy_impl::SyncBefore(this->order(), src.order(), order);

if (!IsValidType(src.type())) {
assert(!src.has_data() && "It is not possible to have data without a valid type.");
Reset();
SetLayout(src.GetLayout());
// no copying to do
return;
}

Resize(src.shape(), src.type());
// After Resize, the postconditions on state_, curr_num_tensors_, type_, sample_dim_, shape_
// (and pinned) are met, and the buffers are correctly adjusted.
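To illustrate the new order semantics from the caller's side, a minimal sketch follows; the stream setup, sizes, device id, and the external buffer are illustrative assumptions, not code from this PR. SetSample stores the sample with the batch's order and makes the batch's order wait for the incoming order, so subsequent access in the batch's order is safe:

#include <cuda_runtime.h>
#include <memory>
#include "dali/pipeline/data/tensor_list.h"

void ShareExternalSample() {
  using namespace dali;
  TensorList<GPUBackend> batch;
  batch.Resize(uniform_list_shape(4, TensorShape<>{16, 16, 3}), DALI_FLOAT);

  // An external producer fills the buffer on its own stream.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  AccessOrder src_order(stream);

  size_t bytes = 16 * 16 * 3 * sizeof(float);
  void *raw = nullptr;
  cudaMalloc(&raw, bytes);
  std::shared_ptr<void> ext(raw, [](void *p) { cudaFree(p); });

  // Metadata (type, sample dim, pinned, device_id) must match the batch;
  // the order may differ. SetSample makes batch.order() wait for src_order
  // and the stored sample keeps the batch's order.
  batch.SetSample(0, ext, bytes, /* pinned = */ false, TensorShape<>{16, 16, 3},
                  DALI_FLOAT, /* device_id = */ 0, src_order);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
}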
35 changes: 27 additions & 8 deletions dali/pipeline/data/tensor_list.h
@@ -219,8 +219,11 @@ class DLL_PUBLIC TensorList {
* function would still report that they are sharing data. It is advised that all samples are
* replaced this way; otherwise, the contiguous allocation would be kept alive.
*
* The metadata (pinned, type, device_id, order, layout) must match what is already set
* for the whole batch to maintain consistency.
* The metadata (pinned, type, device_id, layout) must match what is already set for the whole
* batch to maintain consistency.
*
* We wait in the batch's order for any work pending in the incoming sample's order, so that
* the new sample can be correctly accessed in the batch's order.
*
* @param sample_idx index of sample to be set
* @param src owner of source sample
@@ -238,8 +241,11 @@
* function would still report that they are sharing data. It is advised that all samples are
* replaced this way; otherwise, the contiguous allocation would be kept alive.
*
* The metadata (pinned, type, device_id, order, layout) must match what is already set
* for the whole batch to maintain consistency.
* The metadata (pinned, type, device_id, layout) must match what is already set for the whole
* batch to maintain consistency.
*
* We wait in the batch's order for any work pending in the incoming sample's order, so that
* the new sample can be correctly accessed in the batch's order.
*
* @param sample_idx index of sample to be set
* @param src sample owner
@@ -256,12 +262,15 @@
* function would still report that they are sharing data. It is advised that all samples are
* replaced this way; otherwise, the contiguous allocation would be kept alive.
*
* The metadata (pinned, type, device_id, order, layout) must match what is already set
* for the whole batch to maintain consistency.
* The metadata (pinned, type, device_id, layout) must match what is already set for the whole
* batch to maintain consistency.
*
* We wait in the batch's order for any work pending in the incoming sample's order, so that
* the new sample can be correctly accessed in the batch's order.
*/
DLL_PUBLIC void SetSample(int sample_idx, const shared_ptr<void> &ptr, size_t bytes, bool pinned,
const TensorShape<> &shape, DALIDataType type, int device_id,
AccessOrder order = {}, const TensorLayout &layout = "");
AccessOrder order, const TensorLayout &layout = "");
/** @} */
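
As the note about contiguous allocations above says, replacing only some samples keeps the original contiguous buffer alive. A minimal sketch of replacing every sample so the old allocation can be released; MakeSample is a hypothetical helper producing tensors whose metadata matches the batch:

TensorList<CPUBackend> batch;
batch.Resize(uniform_list_shape(2, TensorShape<>{4, 4}), DALI_FLOAT,
             BatchContiguity::Contiguous);

for (int i = 0; i < batch.num_samples(); i++) {
  // MakeSample is hypothetical; it must return a Tensor<CPUBackend> with
  // matching type, sample dim, pinned status and device id.
  Tensor<CPUBackend> t = MakeSample(i);
  batch.SetSample(i, t);  // converts the batch to non-contiguous mode
}
// Once all samples are replaced, the old contiguous allocation is no longer
// referenced and can be freed.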

/**
@@ -365,6 +374,16 @@ class DLL_PUBLIC TensorList {
DLL_PUBLIC void Resize(const TensorListShape<> &new_shape, DALIDataType new_type,
BatchContiguity state = BatchContiguity::Automatic);

/**
* @brief Resize an individual sample. Only allowed in non-contiguous mode - the first call
* converts the TensorList to it. The type must already be set, and the TensorList must have
* enough samples for this operation.
*
* @param sample_idx sample index to be resized
* @param new_shape requested shape
*/
DLL_PUBLIC void ResizeSample(int sample_idx, const TensorShape<> &new_shape);
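
A short usage sketch of the preconditions described above (shapes and types are illustrative):

TensorList<CPUBackend> batch;
// Throws: the type is not set yet.
// batch.ResizeSample(0, TensorShape<>{2, 2});
batch.Resize(uniform_list_shape(3, TensorShape<>{2, 2}), DALI_INT32);
batch.ResizeSample(1, TensorShape<>{5, 5});  // ok, batch becomes non-contiguous
// Throws: sample dimension (1 vs 2) does not match.
// batch.ResizeSample(1, TensorShape<>{5});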

/**
* @brief Reserve memory as one contiguous allocation
*/
@@ -703,7 +722,7 @@ class DLL_PUBLIC TensorList {
* @param error_suffix Additional description added to the error message
*/
void VerifySampleShareCompatibility(DALIDataType type, int sample_dim, TensorLayout layout,
bool pinned, AccessOrder order, int device_id,
bool pinned, int device_id,
const std::string &error_suffix = ".");

/**
Expand Down
78 changes: 73 additions & 5 deletions dali/pipeline/data/tensor_list_test.cc
@@ -509,7 +509,7 @@ TYPED_TEST(TensorListTest, TestCopy) {

TYPED_TEST(TensorListTest, TestCopyEmpty) {
using Backend = std::tuple_element_t<0, TypeParam>;
TensorList<Backend> tl;
TensorList<Backend> tl, uninitialized;
tl.SetContiguity(this->kContiguity);

tl.template set_type<float>();
@@ -524,6 +524,13 @@
ASSERT_EQ(tl.type(), tl2.type());
ASSERT_EQ(tl._num_elements(), tl2._num_elements());
ASSERT_EQ(tl.GetLayout(), tl2.GetLayout());

tl2.Copy(uninitialized);
ASSERT_FALSE(tl2.has_data());
ASSERT_EQ(uninitialized.num_samples(), tl2.num_samples());
ASSERT_EQ(uninitialized.type(), tl2.type());
ASSERT_EQ(uninitialized._num_elements(), tl2._num_elements());
ASSERT_EQ(uninitialized.GetLayout(), tl2.GetLayout());
}
}

@@ -954,10 +961,6 @@ std::vector<std::pair<std::string, std::function<void(TensorList<Backend> &)>>>
{"layout", [layout](TensorList<Backend> &t) { t.SetLayout(layout); }},
{"device id", [device_id](TensorList<Backend> &t) { t.set_device_id(device_id); }},
{"pinned", [pinned](TensorList<Backend> &t) { t.set_pinned(pinned); }},
{"order",
[device_id](TensorList<Backend> &t) {
t.set_order(AccessOrder(cuda_stream, device_id));
}},
};
}

@@ -1327,6 +1330,71 @@ TYPED_TEST(TensorListSuite, NoncontiguousResize) {
}


TYPED_TEST(TensorListSuite, ResizeSample) {
TensorList<TypeParam> tv;
tv.SetContiguity(BatchContiguity::Automatic);

auto new_shape = TensorListShape<>{{1, 2, 3}, {2, 3, 4}, {3, 4, 5}};
tv.Resize(new_shape, DALI_FLOAT, BatchContiguity::Contiguous);
EXPECT_TRUE(tv.IsContiguous());

for (int i = 0; i < 3; i++) {
FillWithNumber(tv[i], 1 + i * 1.f);
}

for (int i = 0; i < 3; i++) {
EXPECT_NE(tv[i].raw_data(), nullptr);
EXPECT_EQ(tv[i].shape(), new_shape[i]);
EXPECT_EQ(tv[i].type(), DALI_FLOAT);
CompareWithNumber(tv[i], 1 + i * 1.f);
}

auto new_sample_shape = TensorShape<>{10, 10, 3};
new_shape.set_tensor_shape(1, new_sample_shape);

tv.ResizeSample(1, new_sample_shape);

EXPECT_FALSE(tv.IsContiguous());
for (int i = 0; i < 3; i++) {
EXPECT_EQ(tv[i].shape(), new_shape[i]);
}

FillWithNumber(tv[1], 42.f);

EXPECT_EQ(tv[1].shape(), new_sample_shape);
EXPECT_EQ(tv[1].type(), DALI_FLOAT);
CompareWithNumber(tv[1], 42.f);

auto new_smaller_sample_shape = TensorShape<>{5, 5, 3};
new_shape.set_tensor_shape(1, new_smaller_sample_shape);

tv.ResizeSample(1, new_smaller_sample_shape);

EXPECT_FALSE(tv.IsContiguous());
for (int i = 0; i < 3; i++) {
EXPECT_EQ(tv[i].shape(), new_shape[i]);
}

FillWithNumber(tv[1], 42.f);

EXPECT_EQ(tv[1].shape(), new_smaller_sample_shape);
EXPECT_EQ(tv[1].type(), DALI_FLOAT);
CompareWithNumber(tv[1], 42.f);
}


TYPED_TEST(TensorListSuite, ResizeSampleProhibited) {
TensorList<TypeParam> tv;
auto sample_shape = TensorShape<>{10, 10};
tv.SetContiguity(BatchContiguity::Automatic);
EXPECT_THROW(tv.ResizeSample(0, sample_shape), std::runtime_error);
auto shape = TensorListShape<>{{1, 2, 3}, {2, 3, 4}, {3, 4, 5}};
tv.Resize(shape, DALI_FLOAT, BatchContiguity::Contiguous);

EXPECT_THROW(tv.ResizeSample(1, sample_shape), std::runtime_error);
}


TYPED_TEST(TensorListSuite, BreakContiguity) {
TensorList<TypeParam> tv;
// anything goes