From 7403586b13f76b93fff7412eef91873b13bed77a Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Tue, 31 Jan 2023 12:07:17 +0100
Subject: [PATCH 1/3] Set requested batch size based on the op tensor
 arguments if available

Signed-off-by: Kamil Tokarski
---
 dali/pipeline/executor/executor.cc | 29 +++++++++++++++++++++++++----
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/dali/pipeline/executor/executor.cc b/dali/pipeline/executor/executor.cc
index 74e4d22cb6b..70091ba4fda 100644
--- a/dali/pipeline/executor/executor.cc
+++ b/dali/pipeline/executor/executor.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -30,6 +30,24 @@ namespace dali {
 
+/**
+ * @brief Takes the batch size from any of the op's tensor inputs.
+ *
+ * If no inputs were specified, a batch size inferred from
+ * the stage queue is used instead.
+ */
+inline int InferBatchSizeFromInput(const Workspace &ws, int stage_batch_size) {
+  if (ws.NumInput() > 0) {
+    return ws.GetInputBatchSize(0);
+  }
+  const ArgumentWorkspace &argument_ws = ws;
+  if (begin(argument_ws) != end(argument_ws)) {
+    auto [name, arg] = *begin(argument_ws);
+    return arg.tvec->num_samples();
+  }
+  return stage_batch_size;
+}
+
 template <typename WorkspacePolicy, typename QueuePolicy>
 void Executor<WorkspacePolicy, QueuePolicy>::PreRun() {
   auto batch_size = InferBatchSize(batch_size_providers_);
@@ -97,7 +115,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPUImpl(size_t iteration_id) {
     return;
   }
 
-  auto batch_size = batch_sizes_cpu_.front();
+  int stage_batch_size = batch_sizes_cpu_.front();
   batch_sizes_cpu_.pop();
 
   // Run the cpu-ops in the thread
@@ -106,6 +124,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPUImpl(size_t iteration_id) {
       OpNode &op_node = graph_->Node(OpType::CPU, cpu_op_id);
       decltype(auto) ws =
          ws_policy_.template GetWorkspace<OpType::CPU>(cpu_idxs, *graph_, cpu_op_id);
+      int batch_size = InferBatchSizeFromInput(ws, stage_batch_size);
       ws.SetBatchSizes(batch_size);
 
       DomainTimeRange tr("[DALI][CPU op] " + op_node.instance_name, DomainTimeRange::kBlue1);
@@ -143,7 +162,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixedImpl(size_t iteration_id) {
   if (device_id_ != CPU_ONLY_DEVICE_ID)
     CUDA_CALL(cudaEventSynchronize(mixed_stage_event_));
 
-  auto batch_size = batch_sizes_mixed_.front();
+  int stage_batch_size = batch_sizes_mixed_.front();
   batch_sizes_mixed_.pop();
 
   for (int i = 0; i < graph_->NumOp(OpType::MIXED) && !exec_error_; ++i) {
@@ -151,6 +170,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixedImpl(size_t iteration_id) {
     try {
       decltype(auto) ws = ws_policy_.template GetWorkspace<OpType::MIXED>(mixed_idxs, *graph_, i);
+      int batch_size = InferBatchSizeFromInput(ws, stage_batch_size);
       ws.SetBatchSizes(batch_size);
 
       DomainTimeRange tr("[DALI][Mixed op] " + op_node.instance_name, DomainTimeRange::kOrange);
@@ -208,7 +228,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPUImpl(size_t iteration_id) {
   // iterations of a stage of the pipeline.
   CUDA_CALL(cudaEventSynchronize(gpu_stage_event_));
 
-  auto batch_size = batch_sizes_gpu_.front();
+  int stage_batch_size = batch_sizes_gpu_.front();
   batch_sizes_gpu_.pop();
 
   for (int i = 0; i < graph_->NumOp(OpType::GPU) && !exec_error_; ++i) {
@@ -216,6 +236,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPUImpl(size_t iteration_id) {
     try {
      decltype(auto) ws = ws_policy_.template GetWorkspace<OpType::GPU>(gpu_idxs, *graph_, i);
+      int batch_size = InferBatchSizeFromInput(ws, stage_batch_size);
       ws.SetBatchSizes(batch_size);
 
       auto parent_events = ws.ParentEvents();
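
Illustration (an editor's sketch, not part of the patch series): the case this change addresses is an
operator that has no regular input but does take a named tensor argument, so the requested batch size
has to be read from that argument input; under the conditional split/merge exercised by the tests in
PATCH 3/3 a branch can carry fewer samples than the pipeline maximum, which is when the stage-queue
value stops being good enough. The snippet below assumes a DALI build with the Python API; the
pipeline name, batch size and angle values are arbitrary choices for the sketch.

import numpy as np
from nvidia.dali import fn, pipeline_def


@pipeline_def(batch_size=3, num_threads=1, device_id=None)
def angles_only_pipeline():
    # transforms.rotation gets no positional input here; `angle` is a named
    # tensor argument with one scalar per sample and is the only per-sample
    # input the executor can take the requested batch size from.
    angle = fn.external_source(
        source=lambda: [np.array(15.0 * i, dtype=np.float32) for i in range(3)], batch=True)
    return fn.transforms.rotation(angle=angle)


pipe = angles_only_pipeline()
pipe.build()
matrices, = pipe.run()
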
From 76a3e4290fd67b50abda51afe52a3b9c6b318a41 Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Tue, 31 Jan 2023 14:39:57 +0100
Subject: [PATCH 2/3] Include arg inputs in the executor's check for an
 all-empty batch

Signed-off-by: Kamil Tokarski
---
 dali/pipeline/executor/executor.cc | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/dali/pipeline/executor/executor.cc b/dali/pipeline/executor/executor.cc
index 70091ba4fda..df19fb8595d 100644
--- a/dali/pipeline/executor/executor.cc
+++ b/dali/pipeline/executor/executor.cc
@@ -30,18 +30,26 @@ namespace dali {
 
+inline bool HasTensorArgInputs(const ArgumentWorkspace& argument_ws) {
+  return begin(argument_ws) != end(argument_ws);
+}
+
 /**
  * @brief Takes the batch size from any of the op's tensor inputs.
  *
  * If no inputs were specified, a batch size inferred from
  * the stage queue is used instead.
+ *
+ * Note that this assumes a uniform batch size across all inputs
+ * and outputs, which holds for most operators; split and merge
+ * operators are the notable exception and cannot rely on this value.
  */
 inline int InferBatchSizeFromInput(const Workspace &ws, int stage_batch_size) {
   if (ws.NumInput() > 0) {
     return ws.GetInputBatchSize(0);
   }
   const ArgumentWorkspace &argument_ws = ws;
-  if (begin(argument_ws) != end(argument_ws)) {
+  if (HasTensorArgInputs(argument_ws)) {
     auto [name, arg] = *begin(argument_ws);
     return arg.tvec->num_samples();
   }
@@ -353,11 +361,15 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunHelper(OpNode &op_node, Workspac
   }
 
   // Assuming that most operators don't expect empty input, and expect consistent input.
-  if (ws.NumInput() > 0) {
+  if (ws.NumInput() > 0 || HasTensorArgInputs(ws)) {
     bool all_inputs_empty = true;
     for (int i = 0; i < ws.NumInput(); i++) {
       all_inputs_empty = all_inputs_empty && ws.GetInputBatchSize(i) == 0;
     }
+    const ArgumentWorkspace &argument_ws = ws;
+    for (const auto &[name, arg] : argument_ws) {
+      all_inputs_empty = all_inputs_empty && arg.tvec->num_samples() == 0;
+    }
     if (all_inputs_empty) {
       // We skip the execution of this operator and Reset the outputs in case some state was still
       // present.
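
Illustration (an editor's sketch, not part of the patch series): under conditional execution a branch
can receive an empty batch, and an operator inside that branch may have nothing but argument inputs,
which is the situation the extended all-empty check is meant to catch. The sketch assumes a DALI build
with conditionals (`enable_conditionals=True`); the pipeline name, sizes and the all-False predicate
are arbitrary choices made so that the `if` branch runs on zero samples.

import numpy as np
from nvidia.dali import fn, pipeline_def


@pipeline_def(batch_size=4, num_threads=1, device_id=None, enable_conditionals=True)
def empty_branch_pipeline():
    # One scalar boolean predicate per sample, all False in this sketch.
    pred = fn.external_source(source=lambda: [np.array(False) for _ in range(4)], batch=True)
    # Per-sample shapes passed purely as a named tensor argument.
    shape = fn.external_source(
        source=lambda: [np.array([2], dtype=np.int32) for _ in range(4)], batch=True)
    if pred:
        # This branch sees 0 samples; random.normal has only the `shape` argument
        # input, so the executor can only notice the empty batch (and skip the op)
        # by inspecting argument inputs as well.
        out = fn.random.normal(shape=shape)
    else:
        out = fn.random.normal(shape=shape, mean=100.0)
    return out


pipe = empty_branch_pipeline()
pipe.build()
out, = pipe.run()
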
From 4d7b85f54525387fe03016f97a3a8b83d3c870d2 Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Tue, 31 Jan 2023 15:54:48 +0100
Subject: [PATCH 3/3] Add tests for ops that may infer batch size from arg inputs

Signed-off-by: Kamil Tokarski
---
 .../test_pipeline_conditionals.py             | 132 ++++++++++++++++++
 1 file changed, 132 insertions(+)

diff --git a/dali/test/python/conditionals/test_pipeline_conditionals.py b/dali/test/python/conditionals/test_pipeline_conditionals.py
index b9e18149608..ebcedc717b0 100644
--- a/dali/test/python/conditionals/test_pipeline_conditionals.py
+++ b/dali/test/python/conditionals/test_pipeline_conditionals.py
@@ -617,3 +617,135 @@ def one_return():
                               " (both `if` branches). The `else` branch must also have a return"
                               " statement.")):
         one_return()
+
+
+def _tensor_arg_permute_batch_params():
+    batch_sizes = [1, 5, 8]
+    inp0 = [[np.full((2, 2), i, dtype=np.float32) for i in range(batch_size)]
+            for batch_size in batch_sizes]
+    mask_batches = [
+        np.array([i % 2 for i in range(batch_size)], dtype=bool) for batch_size in batch_sizes
+    ]
+    kwarg_batches = [np.array([pred for pred in mask], dtype=np.int32) for mask in mask_batches]
+    return (inp0, ), mask_batches, {'indices': kwarg_batches}
+
+
+def _tensor_arg_transform_per_dim_params(arg_name):
+
+    def inner():
+        batch_sizes = [5, 1, 2, 8]
+        mask_batches = [
+            np.array([i % 2 for i in range(batch_size)], dtype=bool) for batch_size in batch_sizes
+        ]
+        kwarg_batches = [
+            np.array([[pred, pred] for pred in mask], dtype=np.float32) for mask in mask_batches
+        ]
+        return tuple(), mask_batches, {arg_name: kwarg_batches}
+
+    return inner
+
+
+def _tensor_arg_rotate_params():
+    batch_sizes = [3, 1, 2, 4]
+    mask_batches = [
+        np.array([i % 2 for i in range(batch_size)], dtype=bool) for batch_size in batch_sizes
+    ]
+    kwarg_batches = [
+        np.array([10 + 45 * pred for pred in mask], dtype=np.float32) for mask in mask_batches
+    ]
+    return tuple(), mask_batches, {'angle': kwarg_batches}
+
+
+def _tensor_arg_roi_random_crop_params():
+    batch_sizes = [1, 2, 7, 3]
+    crop_shape = [[
+        np.array([100 * i + 50, 200 * i + 50, 3], dtype=np.int32) for i in range(batch_size)
+    ] for batch_size in batch_sizes]
+    roi_start = [[
+        np.array([sample[0] // 2, sample[1] // 2, sample[2]], dtype=np.int32) for sample in batch
+    ] for batch in crop_shape]
+    mask_batches = [
+        np.array([i % 2 for i in range(batch_size)], dtype=bool) for batch_size in batch_sizes
+    ]
+    return tuple(), mask_batches, {
+        'crop_shape': crop_shape,
+        'roi_start': roi_start,
+        'roi_end': crop_shape
+    }
+
+
+def _tensor_arg_shape_kwarg():
+    batch_sizes = [1, 2, 3, 16, 5]
+    shape = [[np.array([1 + 3 * i, 2 * (i + 1) - 1], dtype=np.int32) for i in range(batch_size)]
+             for batch_size in batch_sizes]
+    mask_batches = [
+        np.array([i % 2 for i in range(batch_size)], dtype=bool) for batch_size in batch_sizes
+    ]
+    return tuple(), mask_batches, {'shape': shape}
+
+
+# Test operators that infer their batch sizes from the tensor argument inputs
+@params(fn.permute_batch, fn.roi_random_crop, fn.transforms.crop, fn.transforms.scale,
+        fn.transforms.shear, fn.transforms.translation, fn.transforms.rotation,
+        fn.random.uniform, fn.random.normal, fn.random.coin_flip)
+def test_named_tensor_arguments(op):
+
+    ops2params = {
+        fn.permute_batch: _tensor_arg_permute_batch_params,
+        fn.roi_random_crop: _tensor_arg_roi_random_crop_params,
+        fn.transforms.crop: _tensor_arg_transform_per_dim_params('from_start'),
+        fn.transforms.scale: _tensor_arg_transform_per_dim_params('scale'),
+        fn.transforms.shear: _tensor_arg_transform_per_dim_params('angles'),
+        fn.transforms.translation: _tensor_arg_transform_per_dim_params('offset'),
+        fn.transforms.rotation: _tensor_arg_rotate_params,
+        fn.random.uniform: _tensor_arg_shape_kwarg,
+        fn.random.normal: _tensor_arg_shape_kwarg,
+        fn.random.coin_flip: _tensor_arg_shape_kwarg,
+    }
+
+    def dummy_source(batches):
+
+        def cb():
+            for batch in batches:
+                yield batch
+
+        return cb
+
+    def get_pipeline(op, args_batches, mask_batches, kwargs_batches, num_threads=4, device_id=0):
+        max_batch_size = max(len(batch) for batch in mask_batches)
+
+        @pipeline_def(batch_size=max_batch_size, num_threads=num_threads, device_id=device_id)
+        def split_pipeline():
+            args = [fn.external_source(dummy_source(arg_batches)) for arg_batches in args_batches]
+            mask = fn.external_source(dummy_source(mask_batches))
+            kwargs = {
+                kwarg_name: fn.external_source(dummy_source(batches))
+                for kwarg_name, batches in kwargs_batches.items()
+            }
+            kwargs_split = {
+                kwarg_name: fn._conditional.split(batch, predicate=mask)
+                for kwarg_name, batch in kwargs.items()
+            }
+            split_args = [fn._conditional.split(arg, predicate=mask) for arg in args]
+            left_args = [left_arg for left_arg, _ in split_args]
+            right_args = [right_arg for _, right_arg in split_args]
+            left = op(
+                *left_args,
+                **{kwarg_name: left_kwarg
+                   for kwarg_name, (left_kwarg, _) in kwargs_split.items()})
+            right = op(
+                *right_args, **{
+                    kwarg_name: right_kwarg
+                    for kwarg_name, (_, right_kwarg) in kwargs_split.items()
+                })
+            batch = fn._conditional.merge(left, right, predicate=mask)
+            return batch
+
+        return split_pipeline()
+
+    args_batches, mask_batches, kwargs_batches = ops2params[op]()
+    pipe = get_pipeline(op=op, args_batches=args_batches, mask_batches=mask_batches,
+                        kwargs_batches=kwargs_batches)
+    pipe.build()
+    for _ in range(len(mask_batches)):
+        pipe.run()
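
Illustration (an editor's sketch, not part of the patch series): the test above is table-driven, so
covering another operator that takes its per-sample data through a named tensor argument only needs
one more params builder plus entries in `ops2params` and the `@params` list. A hypothetical builder
following the same conventions; the helper name and the `fill_value` argument are placeholders for
some future operator, not an existing entry in the patch, and the numpy import mirrors the one the
test module already has.

import numpy as np


def _tensor_arg_fill_value_params():
    # Returns the same triple as the other helpers:
    # (args_batches, mask_batches, kwargs_batches).
    batch_sizes = [2, 6, 3]
    mask_batches = [
        np.array([i % 2 for i in range(batch_size)], dtype=bool) for batch_size in batch_sizes
    ]
    fill_values = [
        np.array([float(pred) for pred in mask], dtype=np.float32) for mask in mask_batches
    ]
    return tuple(), mask_batches, {'fill_value': fill_values}
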