Support inferring batch size from tensor argument inputs
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
stiepan committed Jan 26, 2023
1 parent bcedcd0 commit 40726d2
Showing 12 changed files with 34 additions and 47 deletions.
2 changes: 1 addition & 1 deletion dali/operators/generic/constant.h
@@ -83,7 +83,7 @@ class Constant : public Operator<Backend> {
output_.Reset();
}
output_shape_ = max_output_shape_;
- output_shape_.resize(ws.GetRequestedBatchSize(0));
+ output_shape_.resize(InferBatchSizeFromInput(ws));
output_desc[0] = {output_shape_, output_type_};
return false;
}
5 changes: 3 additions & 2 deletions dali/operators/generic/permute_batch.h
@@ -45,9 +45,10 @@ class PermuteBatchBase : public Operator<Backend> {
this->spec_.TryGetRepeatedArgument(indices_, "indices");
}
// TODO(michalz): Remove when fully variable batch size is supported
- DALI_ENFORCE(static_cast<int>(indices_.size()) == ws.GetRequestedBatchSize(0), make_string(
+ int cur_batch_size = InferBatchSizeFromInput(ws);
+ DALI_ENFORCE(static_cast<int>(indices_.size()) == cur_batch_size, make_string(
"The number of sample indices ", indices_.size(), " does not match the current batch size, "
- "which is ", ws.GetRequestedBatchSize(0)));
+ "which is ", cur_batch_size));

auto &out_shape = outputs[0].shape;
int D = in_shape.sample_dim();
4 changes: 1 addition & 3 deletions dali/operators/generic/roi_random_crop.cc
@@ -102,9 +102,7 @@ ROIRandomCropCPU::ROIRandomCropCPU(const OpSpec &spec)

bool ROIRandomCropCPU::SetupImpl(std::vector<OutputDesc> &output_desc,
const Workspace &ws) {
- int nsamples = spec_.HasTensorArgument("crop_shape") ?
-     ws.ArgumentInput("crop_shape").num_samples() :
-     ws.GetRequestedBatchSize(0);
+ int nsamples = InferBatchSizeFromInput(ws);
crop_shape_.Acquire(spec_, ws, nsamples, ArgValue_EnforceUniform);
int ndim = crop_shape_[0].shape[0];

2 changes: 1 addition & 1 deletion dali/operators/geometry/affine_transforms/transform_base_op.h
@@ -72,7 +72,7 @@ class TransformBaseOp : public SequenceOperator<Backend, true> {

bool SetupImpl(std::vector<OutputDesc> &output_descs, const Workspace &ws) override {
has_input_ = ws.NumInput() > 0;
- auto curr_batch_size = has_input_ ? ws.GetInputBatchSize(0) : ws.GetRequestedBatchSize(0);
+ auto curr_batch_size = InferBatchSizeFromInput(ws);
if (has_input_) {
auto &input = ws.Input<Backend>(0);
const auto &shape = input.shape();
2 changes: 1 addition & 1 deletion dali/operators/image/paste/multipaste.h
@@ -101,7 +101,7 @@ class MultiPasteOp : public Operator<Backend> {
void AcquireArguments(const OpSpec &spec, const Workspace &ws) {
const auto &images = ws.Input<Backend>(0);

- auto curr_batch_size = ws.GetRequestedBatchSize(0);
+ auto curr_batch_size = InferBatchSizeFromInput(ws);
if (curr_batch_size == 0)
return;

26 changes: 1 addition & 25 deletions dali/operators/python_function/dltensor_function.h
@@ -192,7 +192,7 @@ class DLTensorPythonFunctionImpl : public Operator<Backend> {
std::lock_guard<std::mutex> operator_guard(operator_lock);
py::gil_scoped_acquire interpreter_guard{};
py::object output_o = py::none();
- auto curr_batch_size = GetCurrBatchSize(ws);
+ auto curr_batch_size = InferBatchSizeFromInput(ws);
try {
detail::StreamSynchronizer<Backend> sync(ws, synchronize_stream_);
if (batch_processing) {
@@ -245,30 +245,6 @@ class DLTensorPythonFunctionImpl : public Operator<Backend> {
bool synchronize_stream_;
bool batch_processing;
std::vector<TensorLayout> output_layouts_;
-
- private:
- int GetCurrBatchSize(Workspace &ws) {
- if (ws.NumInput() > 0) {
- auto curr_batch_size = ws.GetInputBatchSize(0);
- for (int i = 1; i < ws.NumInput(); i++) {
- DALI_ENFORCE(ws.GetInputBatchSize(i) == curr_batch_size,
- make_string("Every input shall have the same batch size. Found inconsistent "
- "batch sizes (val@idx): ",
- ws.GetInputBatchSize(i), "@", i, " vs ", curr_batch_size, "@0."));
- }
- return curr_batch_size;
- } else {
- auto curr_batch_size = ws.GetRequestedBatchSize(0);
- for (int i = 1; i < ws.NumOutput(); i++) {
- DALI_ENFORCE(
- ws.GetRequestedBatchSize(i) == curr_batch_size,
- make_string("This operator assumes, that requested batch size is the same for every "
- "output. Found inconsistent batch sizes (val@idx): ",
- ws.GetRequestedBatchSize(i), "@", i, " vs ", curr_batch_size, "@0."));
- }
- return curr_batch_size;
- }
- }
};

} // namespace dali
2 changes: 1 addition & 1 deletion dali/operators/random/batch_permutation.cc
@@ -32,7 +32,7 @@ points, that is ``out[i] != i``. This argument is ignored when batch size is 1.)

void BatchPermutation::RunImpl(Workspace &ws) {
auto &output = ws.Output<CPUBackend>(0);
- int N = ws.GetRequestedBatchSize(0);
+ int N = InferBatchSizeFromInput(ws);
if (N < 1)
return;
auto out_view = view<int, 0>(output);
2 changes: 1 addition & 1 deletion dali/operators/random/batch_permutation.h
@@ -28,7 +28,7 @@ class BatchPermutation : public Operator<CPUBackend> {

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
output_desc.resize(1);
- output_desc[0].shape = TensorListShape<0>(ws.GetRequestedBatchSize(0));
+ output_desc[0].shape = TensorListShape<0>(InferBatchSizeFromInput(ws));
output_desc[0].type = DALI_INT32;
return true;
}
9 changes: 1 addition & 8 deletions dali/operators/random/rng_base.h
@@ -45,13 +45,6 @@ class RNGBase : public Operator<Backend> {
return true;
}

- int GetBatchSize(const Workspace &ws) const {
- if (spec_.NumRegularInput() == 1)
- return ws.Input<Backend>(0).shape().size();
- else
- return ws.GetRequestedBatchSize(0);
- }
-
bool SetupImpl(std::vector<OutputDesc> &output_desc,
const Workspace &ws) override {
if (IsNoiseGen)
@@ -61,7 +54,7 @@

bool has_shape = spec_.ArgumentDefined("shape");
bool has_shape_like = spec_.NumRegularInput() == 1;
- int nsamples = GetBatchSize(ws);
+ int nsamples = InferBatchSizeFromInput(ws);
DALI_ENFORCE(!(has_shape && has_shape_like),
"Providing argument \"shape\" is incompatible with providing a shape-like input");

2 changes: 1 addition & 1 deletion dali/operators/random/uniform.h
@@ -57,7 +57,7 @@ class Uniform : public Operator<CPUBackend> {

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
output_desc.resize(1);
- auto curr_batch_size = ws.GetRequestedBatchSize(0);
+ auto curr_batch_size = InferBatchSizeFromInput(ws);
output_desc[0].type = DALI_FLOAT;
if (spec_.ArgumentDefined("shape")) {
GetShapeArgument(output_desc[0].shape, spec_, "shape", ws, curr_batch_size);
6 changes: 3 additions & 3 deletions dali/pipeline/operator/operator.cc
@@ -33,7 +33,7 @@ void OperatorBase::EnforceUniformInputBatchSize(const Workspace &ws) const {
return;
}
#endif
- auto curr_batch_size = ws.NumInput() > 0 ? ws.GetInputBatchSize(0) : ws.GetRequestedBatchSize(0);
+ auto curr_batch_size = InferBatchSizeFromInput(ws);
for (int i = 0; i < ws.NumInput(); i++) {
DALI_ENFORCE(curr_batch_size == ws.GetInputBatchSize(i),
"Batch size has to be uniform across one iteration.");
@@ -60,7 +60,7 @@ void OperatorBase::EnforceUniformOutputBatchSize(const Workspace &ws) const {
return;
}
#endif
- auto ref_batch_size = ws.NumInput() > 0 ? ws.GetInputBatchSize(0) : ws.GetRequestedBatchSize(0);
+ auto ref_batch_size = InferBatchSizeFromInput(ws);
for (int i = 0; i < ws.NumOutput(); i++) {
auto output_batch_size = ws.Output<Backend>(i).shape().num_samples();
DALI_ENFORCE(ref_batch_size == output_batch_size,
@@ -73,7 +73,7 @@
template <>
void OperatorBase::EnforceUniformOutputBatchSize<MixedBackend>(
const Workspace &ws) const {
- auto ref_batch_size = ws.NumInput() > 0 ? ws.GetInputBatchSize(0) : ws.GetRequestedBatchSize(0);
+ auto ref_batch_size = InferBatchSizeFromInput(ws);
for (int i = 0; i < ws.NumOutput(); i++) {
auto output_batch_size = const_cast<Workspace &>(ws)
.Output<GPUBackend>(i)
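To illustrate how the inferred reference interacts with the uniformity checks above, here is a minimal, self-contained sketch. EnforceUniformBatch and its types are hypothetical stand-ins invented for this example; the real checks run inside OperatorBase::EnforceUniformInputBatchSize and EnforceUniformOutputBatchSize via DALI_ENFORCE.

#include <stdexcept>
#include <vector>

// Hypothetical mock of the uniformity check; the real implementation
// compares ws.GetInputBatchSize(i) against the inferred reference.
void EnforceUniformBatch(const std::vector<int> &input_batch_sizes, int ref_batch_size) {
  for (int bs : input_batch_sizes) {
    if (bs != ref_batch_size)
      throw std::invalid_argument("Batch size has to be uniform across one iteration.");
  }
}

int main() {
  EnforceUniformBatch({4, 4, 4}, 4);  // OK: every input matches the inferred reference
  try {
    EnforceUniformBatch({4, 2}, 4);   // the second input has 2 samples: rejected
  } catch (const std::invalid_argument &) {
    // expected: non-uniform batch within one iteration
  }
  return 0;
}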
19 changes: 19 additions & 0 deletions dali/pipeline/operator/operator.h
@@ -230,6 +230,25 @@ inline TensorLayout GetOutputLayout(const Workspace &ws, int i) {
return out.GetLayout();
}

+ /**
+  * @brief Takes the batch size from any of the available tensor inputs.
+  *
+  * If no inputs are available, the global output batch size set
+  * by the executor is used. It does not check whether all the
+  * tensor inputs have a matching number of samples.
+  */
+ inline int InferBatchSizeFromInput(const Workspace &ws) {
+   if (ws.NumInput() > 0) {
+     return ws.GetInputBatchSize(0);
+   }
+   const ArgumentWorkspace &argument_ws = ws;
+   if (begin(argument_ws) != end(argument_ws)) {
+     auto [name, arg] = *begin(argument_ws);
+     return arg.tvec->num_samples();
+   }
+   return ws.GetRequestedBatchSize(0);
+ }

template <>
class Operator<CPUBackend> : public OperatorBase {
public:
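For a concrete view of the fallback order implemented by InferBatchSizeFromInput, here is a minimal, self-contained sketch. MockWorkspace and its fields are hypothetical stand-ins for DALI's Workspace/ArgumentWorkspace, invented for this illustration only:

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for DALI's Workspace; the field names are
// invented for this sketch and do not match the real API.
struct MockWorkspace {
  std::vector<int> input_batch_sizes;          // regular tensor inputs
  std::map<std::string, int> argument_inputs;  // tensor argument name -> num_samples
  int requested_batch_size = 0;                // value set by the executor
};

// Mirrors the fallback order of InferBatchSizeFromInput:
// 1) first regular input, 2) any tensor argument input, 3) requested batch size.
int InferBatchSize(const MockWorkspace &ws) {
  if (!ws.input_batch_sizes.empty())
    return ws.input_batch_sizes.front();
  if (!ws.argument_inputs.empty())
    return ws.argument_inputs.begin()->second;
  return ws.requested_batch_size;
}

int main() {
  MockWorkspace ws;
  ws.requested_batch_size = 8;
  assert(InferBatchSize(ws) == 8);       // no inputs at all: executor's value is used

  ws.argument_inputs["crop_shape"] = 4;  // a per-sample tensor argument appears
  assert(InferBatchSize(ws) == 4);       // the argument input now drives the batch size

  ws.input_batch_sizes = {2};
  assert(InferBatchSize(ws) == 2);       // a regular input takes precedence over both
  return 0;
}

The middle branch is what this commit adds: previously, operators with only tensor argument inputs fell straight through to the requested batch size.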
