Commit 0cd1cf1

Ensure sample encapsulation in Tensor Vector (#3701)
Add APIs matching TensorList to TensorVector:
* sample pointer accessors
* Set/GetMeta

Change operator[] to return [Const]SampleView.

Introduce UnsafeSetSample and UnsafeCopySample to replace TensorVector[i].ShareData(tensor) and TensorVector[i].Copy(tensor) - they work with the current code base, but a proper sample-based data structure needs more checks, which are intended for a follow-up.

Adjust code where necessary:
* Where possible, use data accessors directly on the TensorVector instead of on the sample, as it avoids creating a temporary and should be faster: `tv[i].mutable_data<T>()` -> `tv.mutable_tensor<T>(i)`, etc.
* Using SampleViews is compatible with code that uses `view<T>`, as `view<T>(Tensor)` is equivalent to `view<T>(sample_view(Tensor))`.

Adjustments:
* Allow views to work with scalar Tensors (they used to treat them as empty).
* Introduce distinct SampleView and ConstSampleView, as they need to be returned by value and we need sensible overloads for `view<>`.
* Allow access to the `capacity` and `nbytes` of individual samples; introduce _chunks_capacity and _chunks_nbytes for that.

Next steps are written as a TODO in the TensorVector docstring.

Current naming: the `Unsafe` prefix in SetSample and CopySample is intended to stay there temporarily, to discourage introduction of new use cases until the follow-up introduces the remaining checks. Capacity and nbytes of individual allocations have a leading underscore, as that API is to be reworked and is not intended for new usages.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
1 parent 8e36311 commit 0cd1cf1
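To make the accessor change concrete before the per-file diffs, here is a minimal before/after sketch. It is only an illustration assuming the TensorVector APIs visible in this commit; the header path and function name are assumptions, not part of the change:

    #include "dali/pipeline/data/tensor_vector.h"  // assumed location of TensorVector

    namespace dali {

    // Fill the first sample of a CPU batch through the batch-level accessor.
    inline void FillFirstSample(TensorVector<CPUBackend> &tv, int64_t n) {
      // Before: tv[0] returned a Tensor reference, so data was reached through
      // the sample object:
      //   float *ptr = tv[0].mutable_data<float>();
      // After: tv[0] returns a SampleView; raw pointers should instead come
      // from the container itself, which skips the temporary view:
      float *ptr = tv.mutable_tensor<float>(0);
      for (int64_t i = 0; i < n; i++)
        ptr[i] = 0.0f;
    }

    }  // namespace dali

Code that already goes through `view<T>` needs no change, since `view<T>(Tensor)` is equivalent to `view<T>(sample_view(Tensor))`.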

72 files changed: +992 -402 lines changed

dali/benchmark/displacement_cpu_bench.cc (+2 -2)

@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -88,7 +88,7 @@ void DisplacementBench(benchmark::State& st) {//NOLINT
   // tensor out is resized by operator itself in DisplacementFilter::DataDependentSetup()

   // TODO(klecki) Accomodate to use different inputs from test data
-  auto *ptr = (*tensor_in)[0].template mutable_data<T>();
+  auto *ptr = (*tensor_in).template mutable_tensor<T>(0);
   for (int i = 0; i < N; i++) {
     ptr[i] = i;
   }

dali/benchmark/operator_bench.h (+6 -9)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -54,16 +54,13 @@ class OperatorBench : public DALIBenchmark {
     auto op_ptr = InstantiateOperator(op_spec);

     auto data_in = std::make_shared<TensorVector<CPUBackend>>(batch_size);
-    for (auto &in_ptr : *data_in) {
-      in_ptr = std::make_shared<Tensor<CPUBackend>>();
-      in_ptr->set_type<T>();
-      in_ptr->Resize({H, W, C});
-      in_ptr->SetLayout("HWC");
-    }
+    data_in->set_type<T>();
+    data_in->Resize(uniform_list_shape(batch_size, TensorShape<>{H, W, C}));
+    data_in->SetLayout("HWC");

     if (fill_in_data) {
-      for (auto &in_ptr : *data_in) {
-        auto *ptr = in_ptr->template mutable_data<T>();
+      for (int sample_idx = 0; sample_idx < batch_size; sample_idx++) {
+        auto *ptr = data_in->template mutable_tensor<T>(sample_idx);
         for (int i = 0; i < N; i++) {
           ptr[i] = static_cast<T>(i);
         }
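Extracted from the hunk above, the batch-wise setup reads as follows; a sketch assuming the helpers used in this diff (uniform_list_shape, TensorShape), with illustrative sizes:

    const int batch_size = 8, H = 600, W = 800, C = 3;
    const int N = H * W * C;

    // One allocation pass for the whole batch instead of a per-sample loop:
    auto data_in = std::make_shared<TensorVector<CPUBackend>>(batch_size);
    data_in->set_type<float>();
    data_in->Resize(uniform_list_shape(batch_size, TensorShape<>{H, W, C}));
    data_in->SetLayout("HWC");

    // Per-sample pointers still come from the batch-level accessor.
    for (int sample_idx = 0; sample_idx < batch_size; sample_idx++) {
      float *ptr = data_in->mutable_tensor<float>(sample_idx);
      for (int i = 0; i < N; i++)
        ptr[i] = static_cast<float>(i);
    }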

dali/c_api/c_api.cc (+10 -6)

@@ -129,7 +129,6 @@ void SetExternalInputTensors(daliPipelineHandle *pipe_handle, const char *name,
   if (layout_str != nullptr) {
     layout = dali::TensorLayout(layout_str);
   }
-  dali::TensorVector<Backend> data(curr_batch_size);
   auto type_id = static_cast<dali::DALIDataType>(data_type);
   auto elem_sizeof = dali::TypeTable::GetTypeInfo(type_id).size();

@@ -139,15 +138,20 @@
   else
     order = AccessOrder::host();

+  dali::TensorVector<Backend> data(curr_batch_size);
+  data.set_pinned(flags & DALI_ext_pinned);
+  data.set_sample_dim(sample_dim);
+  data.set_type(type_id);
+  data.set_order(order);
+  data.SetLayout(layout);
+
   for (int i = 0; i < curr_batch_size; i++) {
     // We cast away the const from data_ptr, as there is no other way of passing it to the
     // Tensor as we must also set the shape and type metadata.
     // The vector that we pass to pipeline is const.
-    data[i].set_pinned(flags & DALI_ext_pinned);
-    data[i].set_order(order);
-    data[i].ShareData(const_cast<void *>(data_ptr[i]), tl_shape[i].num_elements() * elem_sizeof);
-    data[i].Resize(tl_shape[i], type_id);
-    data[i].SetLayout(layout);
+    std::shared_ptr<void> ptr(const_cast<void *>(data_ptr[i]), [](void *){});  // no deleter
+    data.UnsafeSetSample(i, ptr, tl_shape[i].num_elements() * elem_sizeof, flags & DALI_ext_pinned,
+                         tl_shape[i], type_id, order, layout);
   }
   pipeline->SetExternalInput(name, data, order,
                              flags & DALI_ext_force_sync,
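The wrapping idiom above can be read as a small helper; a sketch whose parameter list mirrors the UnsafeSetSample call in the hunk (the function name is illustrative):

    // Attach one externally owned buffer as sample `i` of `data`, zero-copy.
    // The empty deleter makes the shared_ptr non-owning, so the TensorVector
    // never frees memory that belongs to the caller.
    void WrapExternalSample(dali::TensorVector<dali::CPUBackend> &data, int i,
                            const void *raw, size_t bytes, bool pinned,
                            const dali::TensorShape<> &shape,
                            dali::DALIDataType type_id, dali::AccessOrder order,
                            const dali::TensorLayout &layout) {
      std::shared_ptr<void> ptr(const_cast<void *>(raw), [](void *) {});
      data.UnsafeSetSample(i, ptr, bytes, pinned, shape, type_id, order, layout);
    }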

dali/operators/audio/nonsilence_op.h (+3 -3)

@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -228,8 +228,8 @@ class NonsilenceOperatorCpu : public NonsilenceOperator<CPUBackend> {
         args.reset_interval = reset_interval_;

         auto res = DetectNonsilenceRegion(intermediate_buffers_[thread_id], args);
-        auto beg_ptr = output_begin[sample_id].mutable_data<int>();
-        auto len_ptr = output_length[sample_id].mutable_data<int>();
+        auto *beg_ptr = output_begin.mutable_tensor<int>(sample_id);
+        auto *len_ptr = output_length.mutable_tensor<int>(sample_id);
         *beg_ptr = res.first;
         *len_ptr = res.second;
       }, in_shape.tensor_size(sample_id));

dali/operators/audio/preemphasis_filter_op.cc (+5 -5)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -65,11 +65,11 @@ void PreemphasisFilterCPU::RunImplTyped(workspace_t<CPUBackend> &ws) {
   for (int sample_id = 0; sample_id < nsamples; sample_id++) {
     tp.AddWork(
       [this, &output, &input, sample_id](int thread_id) {
-        const auto in_ptr = input[sample_id].data<InputType>();
-        auto out_ptr = output[sample_id].mutable_data<OutputType>();
-        DALI_ENFORCE(input[sample_id].shape() == output[sample_id].shape(),
+        const auto *in_ptr = input.tensor<InputType>(sample_id);
+        auto *out_ptr = output.mutable_tensor<OutputType>(sample_id);
+        DALI_ENFORCE(input.tensor_shape(sample_id) == output.tensor_shape(sample_id),
                      "Input and output shapes don't match");
-        auto n = volume(output[sample_id].shape());
+        auto n = volume(output.tensor_shape(sample_id));
         auto coeff = preemph_coeff_[sample_id];
         if (coeff == 0.0f) {
           for (int64_t j = 0; j < n; j++) {

dali/operators/decoder/audio/audio_decoder_op.cc (+4 -4)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -88,13 +88,13 @@ AudioDecoderCpu::SetupImpl(std::vector<OutputDesc> &output_desc, const workspace

   for (int i = 0; i < batch_size; i++) {
     auto &meta = sample_meta_[i] =
-        decoders_[i]->Open({static_cast<const char *>(input[i].raw_data()),
-                            input[i].shape().num_elements()});
+        decoders_[i]->Open({static_cast<const char *>(input.raw_tensor(i)),
+                            input.tensor_shape(i).num_elements()});
     TensorShape<> data_sample_shape = DecodedAudioShape(
         meta, use_resampling_ ? target_sample_rates_[i] : -1.0f, downmix_);
     shape_data.set_tensor_shape(i, data_sample_shape);
     shape_rate.set_tensor_shape(i, {});
-    files_names_[i] = input[i].GetSourceInfo();
+    files_names_[i] = input.GetMeta(i).GetSourceInfo();
   }

   output_desc[0] = { shape_data, output_type_ };
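Sample metadata is likewise reached through the batch object now; a short sketch of the Set/GetMeta pair this commit adds to TensorVector (the function name is illustrative):

    // Propagate per-sample metadata (source info, layout) from input to output,
    // mirroring the pattern that TensorList already supports.
    void PropagateMeta(const dali::TensorVector<dali::CPUBackend> &input,
                       dali::TensorVector<dali::CPUBackend> &output, int i) {
      output.SetMeta(i, input.GetMeta(i));
    }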

dali/operators/decoder/decoder_test.h (+14 -2)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include "dali/pipeline/data/types.h"
 #include "dali/test/dali_test_decoder.h"

 namespace dali {
@@ -64,6 +65,7 @@ class DecodeTestBase : public GenericDecoderTest<ImgType> {
     // single input - encoded images
     // single output - decoded images
     TensorVector<CPUBackend> out(inputs[0]->num_samples());
+    std::vector<Tensor<CPUBackend>> tmp_out(inputs[0]->num_samples());
     const TensorList<CPUBackend> &encoded_data = *inputs[0];
     const int c = this->GetNumColorComp();

@@ -72,7 +74,17 @@ class DecodeTestBase : public GenericDecoderTest<ImgType> {
       auto data_size = volume(encoded_data.tensor_shape(i));
       this->DecodeImage(
           data, data_size, c, this->ImageType(),
-          &out[i], GetCropWindowGenerator(i));
+          &tmp_out[i], GetCropWindowGenerator(i));
+    }
+
+    TensorListShape<> out_shape(inputs[0]->num_samples(), 3);
+    for (size_t i = 0; i < encoded_data.num_samples(); ++i) {
+      out_shape.set_tensor_shape(i, tmp_out[i].shape());
+    }
+    out.SetupLike(tmp_out[0]);
+    out.Resize(out_shape, DALI_UINT8);
+    for (size_t i = 0; i < encoded_data.num_samples(); ++i) {
+      out.UnsafeSetSample(i, tmp_out[i]);
     }

     vector<std::shared_ptr<TensorList<CPUBackend>>> outputs;
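The test above also demonstrates the general recipe for assembling a batch from individually produced Tensors; condensed (a sketch using only calls from this hunk, with an illustrative sample count):

    const int n = 4;  // illustrative number of samples
    std::vector<Tensor<CPUBackend>> tmp(n);
    // ... produce each tmp[i] independently (e.g. by decoding) ...

    TensorListShape<> shape(n, 3);
    for (int i = 0; i < n; ++i)
      shape.set_tensor_shape(i, tmp[i].shape());

    TensorVector<CPUBackend> out(n);
    out.SetupLike(tmp[0]);             // copy type/pinning/layout properties
    out.Resize(shape, DALI_UINT8);     // size the batch to match the samples
    for (int i = 0; i < n; ++i)
      out.UnsafeSetSample(i, tmp[i]);  // share each Tensor, no copy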

dali/operators/decoder/nvjpeg/nvjpeg_decoder_decoupled_api.h (+22 -16)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -554,15 +554,16 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
     samples_jpeg2k_.clear();
 #endif // NVJPEG2K_ENABLED

+    const auto &input = ws.Input<CPUBackend>(0);
     for (int i = 0; i < curr_batch_size; i++) {
-      const auto &in = ws.Input<CPUBackend>(0)[i];
-      const auto in_size = in.size();
-      thread_pool_.AddWork([this, i, &in, in_size](int tid) {
-        auto *input_data = in.data<uint8_t>();
+      auto *input_data = input.tensor<uint8_t>(i);
+      const auto in_size = input.tensor_shape(i).num_elements();
+      const auto &source_info = input.GetMeta(i).GetSourceInfo();
+      thread_pool_.AddWork([this, i, input_data, in_size, source_info](int tid) {
         SampleData &data = sample_data_[i];
         data.clear();
         data.sample_idx = i;
-        data.file_name = in.GetSourceInfo();
+        data.file_name = source_info;
         data.encoded_length = in_size;

         auto cached_shape = CacheImageShape(data.file_name);
@@ -704,15 +705,17 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {

   void ProcessImagesCuda(MixedWorkspace &ws) {
     auto& output = ws.Output<GPUBackend>(0);
+    const auto &input = ws.Input<CPUBackend>(0);
     for (auto *sample : samples_single_) {
       assert(sample);
       auto i = sample->sample_idx;
       auto *output_data = output.mutable_tensor<uint8_t>(i);
-      const auto &in = ws.Input<CPUBackend>(0)[i];
+      const auto *in_data = input.tensor<uint8_t>(i);
+      const auto in_size = input.tensor_shape(i).num_elements();
       thread_pool_.AddWork(
-        [this, sample, &in, output_data](int tid) {
-          SampleWorker(sample->sample_idx, sample->file_name, in.size(), tid,
-                       in.data<uint8_t>(), output_data, streams_[tid]);
+        [this, sample, in_data, in_size, output_data](int tid) {
+          SampleWorker(sample->sample_idx, sample->file_name, in_size, tid,
+                       in_data, output_data, streams_[tid]);
         }, task_priority_seq_--); // FIFO order, since the samples were already ordered
     }
   }
@@ -808,15 +811,17 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
   }

   void ProcessImagesHost(MixedWorkspace &ws) {
+    const auto &input = ws.Input<CPUBackend>(0);
     auto& output = ws.Output<GPUBackend>(0);
     for (auto *sample : samples_host_) {
       auto i = sample->sample_idx;
+      const auto *input_data = input.tensor<uint8_t>(i);
+      auto in_size = input.tensor_shape(i).num_elements();
       auto *output_data = output.mutable_tensor<uint8_t>(i);
-      const auto &in = ws.Input<CPUBackend>(0)[i];
       ImageCache::ImageShape shape = output_shape_[i].to_static<3>();
       thread_pool_.AddWork(
-        [this, sample, &in, output_data, shape](int tid) {
-          HostFallback<StorageGPU>(in.data<uint8_t>(), in.size(), output_image_type_, output_data,
+        [this, sample, input_data, in_size, output_data, shape](int tid) {
+          HostFallback<StorageGPU>(input_data, in_size, output_image_type_, output_data,
                                    streams_[tid], sample->file_name, sample->roi, use_fast_idct_);
           CacheStore(sample->file_name, output_data, shape, streams_[tid]);
         }, task_priority_seq_--); // FIFO order, since the samples were already ordered
@@ -846,13 +851,14 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
     int j = 0;
     TensorVector<CPUBackend> tv(samples_hw_batched_.size());

+    const auto &input = ws.Input<CPUBackend>(0);
+    tv.SetupLike(input);
     for (auto *sample : samples_hw_batched_) {
       int i = sample->sample_idx;
-      const auto &in = ws.Input<CPUBackend>(0)[i];
       const auto &out_shape = output_shape_.tensor_shape(i);

-      tv[j].ShareData(const_cast<Tensor<CPUBackend> &>(in));
-      in_lengths_[j] = in.size();
+      tv.UnsafeSetSample(j, input, i);
+      in_lengths_[j] = input.tensor_shape(i).num_elements();
       nvjpeg_destinations_[j].channel[0] = output.mutable_tensor<uint8_t>(i);
       nvjpeg_destinations_[j].pitch[0] = out_shape[1] * out_shape[2];
       nvjpeg_params_[j] = sample->params;
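The last hunk shows the batched gather in miniature: instead of tv[j].ShareData(in), a subset of input samples is aliased into a smaller batch. Schematically (selected_indices is an illustrative name, not from the diff):

    // Collect a subset of input samples into a smaller batch by aliasing.
    TensorVector<CPUBackend> tv(selected_indices.size());
    tv.SetupLike(input);                  // match type, pinning and layout
    int j = 0;
    for (int i : selected_indices)
      tv.UnsafeSetSample(j++, input, i);  // tv sample j aliases input sample i, zero-copy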

dali/operators/generic/cast.cc (+2 -2)

@@ -51,8 +51,8 @@ void CastCPU::RunImpl(HostWorkspace &ws) {
   TYPE_SWITCH(itype, type2id, IType, CAST_ALLOWED_TYPES, (

     for (int sample_id = 0; sample_id < num_samples; sample_id++) {
-      auto *out = output[sample_id].mutable_data<OType>();
-      const auto *in = input[sample_id].data<IType>();
+      auto *out = output.mutable_tensor<OType>(sample_id);
+      const auto *in = input.tensor<IType>(sample_id);
       auto size = input_shape.tensor_size(sample_id);
       tp.AddWork([out, in, size](int thread_id) { CpuHelper<OType, IType>(out, in, size); },
                  size);

dali/operators/generic/constant.cc (+4 -4)

@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -80,7 +80,7 @@ void FillTensorVector(
   assert(is_uniform(shape));
   int64_t n = shape[0].num_elements();
   assert(src.size() == static_cast<size_t>(n) || src.size() == 1);
-  Dst *out = dst[0].mutable_data<Dst>();
+  Dst *out = dst.mutable_tensor<Dst>(0);
   if (src.size() == 1) {
     Dst val = ConvertSat<Dst>(src[0]);
     for (int64_t i = 0; i < n; i++) {
@@ -92,7 +92,7 @@
     }
   }
   for (int i = 1; i < shape.num_samples(); i++) {
-    dst[i].ShareData(dst[0]);
+    dst.UnsafeSetSample(i, dst, 0);
   }
 }
 }  // namespace
@@ -116,7 +116,7 @@ void Constant<CPUBackend>::RunImpl(HostWorkspace &ws) {
   out.Resize(output_shape_);
   int N = output_shape_.num_samples();
   for (int i = 0; i < N; i++) {
-    assert(out[i].raw_data() == output_[i].raw_data());
+    assert(out.raw_tensor(i) == output_.raw_tensor(i));
   }
   out.SetLayout(layout_);
 }
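The middle hunk is the broadcast pattern for constants: fill sample 0 once, then alias it from every other slot. As a standalone sketch (function name and element type illustrative):

    // Broadcast one buffer across the whole batch without copying.
    void BroadcastFirstSample(dali::TensorVector<dali::CPUBackend> &dst,
                              int num_samples, int64_t n, float value) {
      float *out = dst.mutable_tensor<float>(0);
      for (int64_t i = 0; i < n; i++)
        out[i] = value;
      for (int i = 1; i < num_samples; i++)
        dst.UnsafeSetSample(i, dst, 0);  // sample i shares sample 0's allocation
    }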

dali/operators/generic/erase/erase_utils.h (+9 -9)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -95,17 +95,17 @@ std::vector<kernels::EraseArgs<T, Dims>> GetEraseArgs(const OpSpec &spec,

   for (int i = 0; i < nsamples; i++) {
     if (has_tensor_roi_anchor) {
-      const auto& anchor = ws.ArgumentInput("anchor")[i];
-      assert(anchor.size() > 0);
-      roi_anchor.resize(anchor.size());
-      std::memcpy(roi_anchor.data(), anchor.data<float>(), sizeof(float) * roi_anchor.size());
+      auto anchor = view<const float>(ws.ArgumentInput("anchor")[i]);
+      assert(anchor.shape.num_elements() > 0);
+      roi_anchor.resize(anchor.shape.num_elements());
+      std::memcpy(roi_anchor.data(), anchor.data, sizeof(float) * roi_anchor.size());
     }

     if (has_tensor_roi_shape) {
-      const auto& shape = ws.ArgumentInput("shape")[i];
-      assert(shape.size() > 0);
-      roi_shape.resize(shape.size());
-      std::memcpy(roi_shape.data(), shape.data<float>(), sizeof(float) * roi_shape.size());
+      auto shape = view<const float>(ws.ArgumentInput("shape")[i]);
+      assert(shape.shape.num_elements() > 0);
+      roi_shape.resize(shape.num_elements());
+      std::memcpy(roi_shape.data(), shape.data, sizeof(float) * roi_shape.size());
     }

     DALI_ENFORCE(roi_anchor.size() == roi_shape.size());
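The view<> idiom used here carries the sample's pointer and shape together, and after this commit it applies uniformly to Tensors and to the SampleViews returned by operator[]. A minimal sketch of reading one argument sample (variable names illustrative):

    auto v = view<const float>(ws.ArgumentInput("anchor")[i]);
    int64_t n = v.shape.num_elements();  // element count from the view's shape
    std::vector<float> values(n);
    std::memcpy(values.data(), v.data, sizeof(float) * n);  // v.data is the raw pointer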

dali/operators/generic/lookup_table.cc (+3 -3)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ void LookupValuesImpl(ThreadPool &tp, TensorVector<CPUBackend> &output,
                       const Output *lookup_table, const Output default_value) {
   for (int sample_idx = 0; sample_idx < shape.num_samples(); sample_idx++) {
     auto data_size = shape.tensor_size(sample_idx);
-    auto *out_data = output[sample_idx].mutable_data<Output>();
-    const auto *in_data = input[sample_idx].data<Input>();
+    auto *out_data = output.mutable_tensor<Output>(sample_idx);
+    const auto *in_data = input.tensor<Input>(sample_idx);
     tp.AddWork(
         [=](int thread_id) {
           for (int64_t i = 0; i < data_size; i++) {

dali/operators/generic/permute_batch.cc (+3 -2)

@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -45,7 +45,8 @@ void PermuteBatch<CPUBackend>::RunImpl(HostWorkspace &ws) {
       int src = indices_[i];
      tp.AddWork([&, i, src](int tid) {
         output.SetMeta(i, input.GetMeta(i));
-        output[i].Copy(input[src]);
+        // TODO(klecki): SetSample
+        output.UnsafeCopySample(i, input, src);
       }, size);
     }
     tp.RunAll();
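For contrast with the sharing API: UnsafeCopySample performs an actual copy of the source sample rather than aliasing it, which is what this hunk uses. The loop above, condensed:

    // Reorder a batch: output sample i receives a copy of input sample indices_[i].
    for (int i = 0; i < batch_size; i++) {
      int src = indices_[i];
      output.SetMeta(i, input.GetMeta(i));     // metadata handled per slot, as above
      output.UnsafeCopySample(i, input, src);  // copies contents, unlike UnsafeSetSample
    }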

dali/operators/generic/reshape.cc (+3 -3)

@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -393,8 +393,8 @@ void Reshape<CPUBackend>::RunImpl(HostWorkspace &ws) {
   out.Resize(output_shape_, output_type_->id());
   int N = output_shape_.num_samples();
   for (int i = 0; i < N; i++) {
-    assert(out[i].raw_data() == in[i].raw_data());
-    assert(out[i].shape() == output_shape_[i]);
+    assert(out.raw_tensor(i) == in.raw_tensor(i));
+    assert(out.tensor_shape(i) == output_shape_[i]);
   }
   out.SetLayout(layout);
 }
