diff --git a/dali/c_api/c_api.cc b/dali/c_api/c_api.cc index 8b5c9b0c7bf..4de753308a5 100644 --- a/dali/c_api/c_api.cc +++ b/dali/c_api/c_api.cc @@ -31,7 +31,6 @@ void daliCreatePipeline(daliPipelineHandle* pipe_handle, batch_size, num_threads, device_id, - -1, true, true); pipe->Build(); @@ -51,19 +50,6 @@ void daliOutput(daliPipelineHandle* pipe_handle) { pipeline->Outputs(ws); } -void* daliTensorAt(daliPipelineHandle* pipe_handle, int n) { - dali::DeviceWorkspace* ws = reinterpret_cast(pipe_handle->ws); - if (ws->OutputIsType(n)) { - dali::Tensor *t = new dali::Tensor(); - t->ShareData(ws->Output(n)); - return t; - } else { - dali::Tensor *t = new dali::Tensor(); - t->ShareData(ws->Output(n)); - return t; - } -} - int64_t* daliShapeAt(daliPipelineHandle* pipe_handle, int n) { dali::DeviceWorkspace* ws = reinterpret_cast(pipe_handle->ws); int64_t* c_shape = nullptr; diff --git a/dali/c_api/c_api.h b/dali/c_api/c_api.h index eeb0f0a00bb..4794ab0c13c 100644 --- a/dali/c_api/c_api.h +++ b/dali/c_api/c_api.h @@ -26,17 +26,45 @@ extern "C" { void* ws; }; + /** + * @brief Create DALI pipeline. Setting batch_size, + * num_threads or device_id here overrides + * values stored in the serialized pipeline. + */ DLL_PUBLIC void daliCreatePipeline(daliPipelineHandle* pipe_handle, const char *serialized_pipeline, int length, - int batch_size, - int num_threads, - int device_id); + int batch_size = -1, + int num_threads = -1, + int device_id = -1); + + /** + * @brief Start the execution of the pipeline. + */ DLL_PUBLIC void daliRun(daliPipelineHandle* pipe_handle); + + /** + * @brief Wait till the output of the pipeline is ready. + */ DLL_PUBLIC void daliOutput(daliPipelineHandle* pipe_handle); - DLL_PUBLIC void* daliTensorAt(daliPipelineHandle* pipe_handle, int n); + + /** + * @brief Return the shape of the output tensor + * stored at position `n` in the pipeline. + * This function may only be called after + * calling Output function. 
+ */ DLL_PUBLIC int64_t* daliShapeAt(daliPipelineHandle* pipe_handle, int n); + + /** + * @brief Copy the output tensor stored + * at position `n` in the pipeline. + */ DLL_PUBLIC void daliCopyTensorNTo(daliPipelineHandle* pipe_handle, void* dst, int n); + + /** + * @brief Delete the pipeline object. + */ DLL_PUBLIC void daliDeletePipeline(daliPipelineHandle* pipe_handle); } diff --git a/dali/pipeline/op_graph.cc b/dali/pipeline/op_graph.cc index bcee88f77f4..c7847e5523e 100644 --- a/dali/pipeline/op_graph.cc +++ b/dali/pipeline/op_graph.cc @@ -186,7 +186,7 @@ void OpGraph::AddOp(const OpSpec &spec, const std::string& name) { DALI_ENFORCE(ret.second, "Operator '" + spec.name() + "' has output with name " + name + ", but output " "with this name already exists as output of op '" + - this->node(TensorSourceID(name)).spec.name()); + this->node(TensorSourceID(name)).spec.name() + "'"); } } diff --git a/dali/pipeline/pipeline.cc b/dali/pipeline/pipeline.cc index 27a4b2b280b..70cf5ad3606 100644 --- a/dali/pipeline/pipeline.cc +++ b/dali/pipeline/pipeline.cc @@ -29,9 +29,6 @@ void Pipeline::AddOperator(OpSpec spec, const std::string& inst_name) { DALI_ENFORCE(!built_, "Alterations to the pipeline after " "\"Build()\" has been called are not allowed"); - // Take a copy of the passed OpSpec for serialization purposes - this->op_specs_.push_back(make_pair(inst_name, spec)); - // Validate op device string device = spec.GetArgument("device"); DALI_ENFORCE(device == "cpu" || @@ -148,18 +145,46 @@ void Pipeline::AddOperator(OpSpec spec, const std::string& inst_name) { "Output name insertion failure."); } - // Add the operator to the graph - PrepareOpSpec(&spec); - graph_.AddOp(spec, inst_name); + // Take a copy of the passed OpSpec for serialization purposes + this->op_specs_.push_back(make_pair(inst_name, spec)); + this->op_specs_to_serialize_.push_back(true); } void Pipeline::Build(vector> output_names) { - DeviceGuard g(device_id_); - output_names_ = output_names; 
DALI_ENFORCE(!built_, "\"Build()\" can only be called once."); DALI_ENFORCE(output_names.size() > 0, "User specified zero outputs."); + + // Creating the executor + if (pipelined_execution_ && async_execution_) { + executor_.reset(new AsyncPipelinedExecutor( + batch_size_, num_threads_, + device_id_, bytes_per_sample_hint_, + set_affinity_, max_num_stream_)); + executor_->Init(); + } else if (pipelined_execution_) { + executor_.reset(new PipelinedExecutor( + batch_size_, num_threads_, + device_id_, bytes_per_sample_hint_, + set_affinity_, max_num_stream_)); + } else if (async_execution_) { + DALI_FAIL("Not implemented."); + } else { + executor_.reset(new Executor( + batch_size_, num_threads_, + device_id_, bytes_per_sample_hint_, + set_affinity_, max_num_stream_)); + } + + // Creating the graph + for (auto& name_op_spec : op_specs_) { + string& inst_name = name_op_spec.first; + OpSpec op_spec = name_op_spec.second; + PrepareOpSpec(&op_spec); + graph_.AddOp(op_spec, inst_name); + } + // Validate the output tensors names vector outputs; for (const auto &name_pair : output_names) { @@ -207,11 +232,17 @@ void Pipeline::Build(vector> output_names) { } } + DeviceGuard d(device_id_); // Load the final graph into the executor executor_->Build(&graph_, outputs); built_ = true; } +void Pipeline::SetOutputNames(vector> output_names) { + output_names_ = output_names; +} + + void Pipeline::RunCPU() { DALI_ENFORCE(built_, "\"Build()\" must be called prior to executing the pipeline."); @@ -242,14 +273,21 @@ void Pipeline::Outputs(DeviceWorkspace *ws) { void Pipeline::SetupCPUInput(std::map::iterator it, int input_idx, OpSpec *spec) { if (!it->second.has_contiguous) { - if (graph_.TensorExists(OpSpec::TensorName("contiguous_" + it->first, "cpu"))) return; + // We check if the make contiguous op already exists + std::string op_name = "__MakeContiguous_" + it->first; + if (std::find_if(op_specs_.begin(), op_specs_.end(), + [&op_name] (const std::pair& p) { + return p.first == 
op_name;}) != op_specs_.end()) { + return; + } + OpSpec make_contiguous_spec = OpSpec("MakeContiguous") .AddArg("device", "mixed") .AddInput(it->first, "cpu") .AddOutput("contiguous_" + it->first, "cpu"); - PrepareOpSpec(&make_contiguous_spec); - graph_.AddOp(make_contiguous_spec, "__MakeContiguous_" + it->first); + this->op_specs_.push_back(make_pair("__MakeContiguous_" + it->first, make_contiguous_spec)); + this->op_specs_to_serialize_.push_back(false); } // Update the OpSpec to use the contiguous input @@ -262,14 +300,21 @@ void Pipeline::SetupCPUInput(std::map::iterator it, void Pipeline::SetupGPUInput(std::map::iterator it) { if (it->second.has_gpu) return; - if (graph_.TensorExists(OpSpec::TensorName(it->first, "gpu"))) return; + // We check if the copy_to_dev op already exists + std::string op_name = "__Copy_" + it->first; + if (std::find_if(op_specs_.begin(), op_specs_.end(), + [&op_name] (const std::pair& p) { + return p.first == op_name;}) != op_specs_.end()) { + return; + } + OpSpec copy_to_dev_spec = OpSpec("MakeContiguous") .AddArg("device", "mixed") .AddInput(it->first, "cpu") .AddOutput(it->first, "gpu"); - PrepareOpSpec(&copy_to_dev_spec); - graph_.AddOp(copy_to_dev_spec, "__Copy_" + it->first); + this->op_specs_.push_back(make_pair("__Copy_" + it->first, copy_to_dev_spec)); + this->op_specs_to_serialize_.push_back(false); } void Pipeline::PrepareOpSpec(OpSpec *spec) { @@ -285,6 +330,8 @@ string Pipeline::SerializeToProtobuf() const { dali_proto::PipelineDef pipe; pipe.set_num_threads(this->num_threads()); pipe.set_batch_size(this->batch_size()); + pipe.set_device_id(this->device_id()); + pipe.set_seed(this->original_seed_); // loop over external inputs for (auto &name : external_inputs_) { @@ -293,14 +340,16 @@ string Pipeline::SerializeToProtobuf() const { // loop over ops, create messages and append for (size_t i = 0; i < this->op_specs_.size(); ++i) { - dali_proto::OpDef *op_def = pipe.add_op(); + if (op_specs_to_serialize_[i]) { + dali_proto::OpDef 
*op_def = pipe.add_op(); - const auto& p = this->op_specs_[i]; - const OpSpec& spec = p.second; + const auto& p = this->op_specs_[i]; + const OpSpec& spec = p.second; - // As long as spec isn't an ExternalSource node, serialize - if (spec.name() != "ExternalSource") { - spec.SerializeToProtobuf(op_def, p.first); + // As long as spec isn't an ExternalSource node, serialize + if (spec.name() != "ExternalSource") { + spec.SerializeToProtobuf(op_def, p.first); + } } } diff --git a/dali/pipeline/pipeline.h b/dali/pipeline/pipeline.h index b9270e9f5de..e905acf42d7 100644 --- a/dali/pipeline/pipeline.h +++ b/dali/pipeline/pipeline.h @@ -84,53 +84,46 @@ class DLL_PUBLIC Pipeline { bool pipelined_execution = true, bool async_execution = true, size_t bytes_per_sample_hint = 0, bool set_affinity = false, int max_num_stream = -1) : - built_(false), batch_size_(batch_size), num_threads_(num_threads), - device_id_(device_id), bytes_per_sample_hint_(bytes_per_sample_hint) { - DALI_ENFORCE(batch_size_ > 0, "Batch size must be greater than 0"); - seed_.resize(MAX_SEEDS); - current_seed = 0; - if (seed != -1) { - std::seed_seq ss{seed}; - ss.generate(seed_.begin(), seed_.end()); - } else { - std::seed_seq ss{time(0)}; - ss.generate(seed_.begin(), seed_.end()); - } - - if (pipelined_execution && async_execution) { - executor_.reset(new AsyncPipelinedExecutor( - batch_size, num_threads, - device_id, bytes_per_sample_hint, - set_affinity, max_num_stream)); - executor_->Init(); - } else if (pipelined_execution) { - executor_.reset(new PipelinedExecutor( - batch_size, num_threads, - device_id, bytes_per_sample_hint, - set_affinity, max_num_stream)); - } else if (async_execution) { - DALI_FAIL("Not implemented."); - } else { - executor_.reset(new Executor( - batch_size, num_threads, - device_id, bytes_per_sample_hint, - set_affinity, max_num_stream)); - } + built_(false) { + Init(batch_size, num_threads, device_id, seed, + pipelined_execution, async_execution, + bytes_per_sample_hint, 
set_affinity, + max_num_stream); } DLL_PUBLIC inline Pipeline(const string &serialized_pipe, - int batch_size, int num_threads, int device_id, int seed = -1, + int batch_size = -1, int num_threads = -1, int device_id = -1, bool pipelined_execution = true, bool async_execution = true, size_t bytes_per_sample_hint = 0, bool set_affinity = false, - int max_num_stream = -1) : - Pipeline(batch_size, num_threads, device_id, seed, pipelined_execution, - async_execution, bytes_per_sample_hint, set_affinity, - max_num_stream) { + int max_num_stream = -1) : built_(false) { dali_proto::PipelineDef def; def.ParseFromString(serialized_pipe); - this->batch_size_ = def.batch_size(); - this->device_id_ = def.device_id(); + // If not given, take parameters from the + // serialized pipeline + if (batch_size == -1) { + this->batch_size_ = def.batch_size(); + } else { + this->batch_size_ = batch_size; + } + if (device_id == -1) { + this->device_id_ = def.device_id(); + } else { + this->device_id_ = device_id; + } + if (num_threads == -1) { + this->num_threads_ = def.num_threads(); + } else { + this->num_threads_ = num_threads; + } + + Init(this->batch_size_, this->num_threads_, + this->device_id_, def.seed(), + pipelined_execution, + async_execution, + bytes_per_sample_hint, + set_affinity, + max_num_stream); // from serialized pipeline, construct new pipeline // All external inputs @@ -244,6 +237,12 @@ class DLL_PUBLIC Pipeline { Build(this->output_names_); } + /** + * @brief Set the output names of the pipeline. Used to update the graph without + * running the executor. + */ + void SetOutputNames(vector> output_names); + /** * @brief Run the cpu portion of the pipeline. 
*/ @@ -301,6 +300,34 @@ class DLL_PUBLIC Pipeline { DLL_PUBLIC DISABLE_COPY_MOVE_ASSIGN(Pipeline); private: + /** + * @brief Initializes the Pipeline internal state + */ + void Init(int batch_size, int num_threads, int device_id, + int seed, bool pipelined_execution, bool async_execution, + size_t bytes_per_sample_hint, bool set_affinity, + int max_num_stream) { + this->batch_size_ = batch_size; + this->num_threads_ = num_threads; + this->device_id_ = device_id; + this->original_seed_ = seed; + this->pipelined_execution_ = pipelined_execution; + this->async_execution_ = async_execution; + this->bytes_per_sample_hint_ = bytes_per_sample_hint; + this->set_affinity_ = set_affinity; + this->max_num_stream_ = max_num_stream; + DALI_ENFORCE(batch_size_ > 0, "Batch size must be greater than 0"); + seed_.resize(MAX_SEEDS); + current_seed = 0; + if (seed != -1) { + std::seed_seq ss{seed}; + ss.generate(seed_.begin(), seed_.end()); + } else { + std::seed_seq ss{time(0)}; + ss.generate(seed_.begin(), seed_.end()); + } + } + using EdgeMeta = struct { bool has_cpu, has_gpu, has_contiguous, is_support; }; @@ -349,8 +376,14 @@ class DLL_PUBLIC Pipeline { bool built_; int batch_size_, num_threads_, device_id_; + bool pipelined_execution_; + bool async_execution_; size_t bytes_per_sample_hint_; + int set_affinity_; + int max_num_stream_; + std::vector seed_; + int original_seed_; size_t current_seed; OpGraph graph_; @@ -362,6 +395,7 @@ class DLL_PUBLIC Pipeline { // serialized form vector external_inputs_; vector> op_specs_; + vector op_specs_to_serialize_; vector> output_names_; }; diff --git a/dali/pipeline/pipeline_test.cc b/dali/pipeline/pipeline_test.cc index 4d06ff6a7d9..1b9c320ef79 100644 --- a/dali/pipeline/pipeline_test.cc +++ b/dali/pipeline/pipeline_test.cc @@ -249,6 +249,9 @@ TEST_F(PipelineTestOnce, TestTriggerToContiguous) { .AddInput("data", "cpu") .AddOutput("data_copy", "gpu")); + vector> outputs = {{"data_copy", "gpu"}}; + pipe.Build(outputs); + OpGraph &graph = 
this->GetGraph(&pipe); // Validate the graph @@ -292,6 +295,9 @@ TEST_F(PipelineTestOnce, TestTriggerCopyToDevice) { .AddInput("data", "gpu") .AddOutput("data_copy", "gpu")); + vector> outputs = {{"data_copy", "gpu"}}; + pipe.Build(outputs); + OpGraph &graph = this->GetGraph(&pipe); // Validate the graph diff --git a/dali/pipeline/proto/dali.proto b/dali/pipeline/proto/dali.proto index bf960d93530..791462a535c 100644 --- a/dali/pipeline/proto/dali.proto +++ b/dali/pipeline/proto/dali.proto @@ -65,4 +65,5 @@ message PipelineDef { // Store all registered outputs repeated InputOutput pipe_outputs = 7; optional int32 device_id = 8 [default = 0]; + optional int32 seed = 9 [default = -1]; } diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 93909606597..a86f6552ec3 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -397,13 +397,13 @@ PYBIND11_MODULE(backend_impl, m) { // initialize from serialized pipeline .def(py::init( [](string serialized_pipe, - int batch_size, int num_threads, int device_id, int seed = -1, + int batch_size, int num_threads, int device_id, bool pipelined_execution = true, bool async_execution = true, size_t bytes_per_sample_hint = 0, bool set_affinity = false, int max_num_stream = -1) { return std::unique_ptr( new Pipeline(serialized_pipe, - batch_size, num_threads, device_id, seed, pipelined_execution, + batch_size, num_threads, device_id, pipelined_execution, async_execution, bytes_per_sample_hint, set_affinity, max_num_stream)); }), @@ -411,7 +411,6 @@ PYBIND11_MODULE(backend_impl, m) { "batch_size"_a, "num_threads"_a, "device_id"_a, - "seed"_a, "exec_pipelined"_a, "exec_async"_a, "bytes_per_sample_hint"_a = 0, @@ -428,6 +427,10 @@ PYBIND11_MODULE(backend_impl, m) { [](Pipeline *p) { p->Build(); }) + .def("SetOutputNames", + [](Pipeline *p, const std::vector>& outputs) { + p->SetOutputNames(outputs); + }) .def("RunCPU", &Pipeline::RunCPU) .def("RunGPU", &Pipeline::RunGPU) .def("Outputs", diff 
--git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py index ab287dbfd67..55d1830691c 100644 --- a/dali/python/nvidia/dali/pipeline.py +++ b/dali/python/nvidia/dali/pipeline.py @@ -18,20 +18,14 @@ from nvidia.dali import tensor as nt class Pipeline(object): - def __init__(self, batch_size, num_threads, device_id, seed = -1, + def __init__(self, batch_size = -1, num_threads = -1, device_id = -1, seed = -1, exec_pipelined=True, exec_async=True, bytes_per_sample=0, set_affinity=False, max_streams=-1): - self._pipe = b.Pipeline(batch_size, - num_threads, - device_id, - seed, - exec_pipelined, - exec_async, - bytes_per_sample, - set_affinity, - max_streams) - self.seed = seed + self._batch_size = batch_size + self._num_threads = num_threads + self._device_id = device_id + self._seed = seed self._exec_pipelined = exec_pipelined self._built = False self._first_iter = True @@ -44,22 +38,33 @@ def __init__(self, batch_size, num_threads, device_id, seed = -1, @property def batch_size(self): - return self._pipe.batch_size() + return self._batch_size @property def num_threads(self): - return self._pipe.num_threads() + return self._num_threads @property def device_id(self): - return self._pipe.device_id() + return self._device_id def epoch_size(self, name = None): + if not self._built: + raise RuntimeError("Pipeline must be built first.") if name is not None: return self._pipe.epoch_size(name) return self._pipe.epoch_size() def _prepare_graph(self): + self._pipe = b.Pipeline(self._batch_size, + self._num_threads, + self._device_id, + self._seed, + self._exec_pipelined, + self._exec_async, + self._bytes_per_sample, + self._set_affinity, + self._max_streams) outputs = self.define_graph() if (not isinstance(outputs, tuple) and not isinstance(outputs, list)): @@ -127,6 +132,8 @@ def build(self): self._built = True def feed_input(self, ref, data): + if not self._built: + raise RuntimeError("Pipeline must be built first.") if not isinstance(ref, 
nt.TensorReference): raise TypeError( "Expected argument one to " @@ -144,15 +151,23 @@ def feed_input(self, ref, data): self._pipe.SetExternalTLInput(ref.name, inp) def run_cpu(self): + if not self._built: + raise RuntimeError("Pipeline must be built first.") self._pipe.RunCPU() def run_gpu(self): + if not self._built: + raise RuntimeError("Pipeline must be built first.") self._pipe.RunGPU() def outputs(self): + if not self._built: + raise RuntimeError("Pipeline must be built first.") return self._pipe.Outputs() def run(self): + if not self._built: + raise RuntimeError("Pipeline must be built first.") if self._first_iter and self._exec_pipelined: self.iter_setup() self.run_cpu() @@ -164,27 +179,28 @@ def run(self): return self.outputs() def serialize(self): - if not self._built: - self.build() + if not self._prepared: + self._prepare_graph() + self._pipe.SetOutputNames(self._names_and_devices) return self._pipe.SerializeToProtobuf() def deserialize_and_build(self, serialized_pipeline): - new_pipe = b.Pipeline(serialized_pipeline, - self.batch_size, - self.num_threads, - self.device_id, - self.seed, - self._exec_pipelined, - self._exec_async, - self._bytes_per_sample, - self._set_affinity, - self._max_streams) - self._pipe = new_pipe + self._pipe = b.Pipeline(serialized_pipeline, + self._batch_size, + self._num_threads, + self._device_id, + self._exec_pipelined, + self._exec_async, + self._bytes_per_sample, + self._set_affinity, + self._max_streams) self._prepared = True self._pipe.Build() self._built = True def save_graph_to_dot_file(self, filename): + if not self._built: + raise RuntimeError("Pipeline must be built first.") self._pipe.SaveGraphToDotFile(filename) # defined by the user to construct their graph of operations. 
diff --git a/dali/tensorflow/daliop.cc b/dali/tensorflow/daliop.cc index 32aaa8036a7..ce34d8738d9 100644 --- a/dali/tensorflow/daliop.cc +++ b/dali/tensorflow/daliop.cc @@ -55,11 +55,11 @@ tf::TensorShape DaliToShape(int64_t* ns) { REGISTER_OP("Dali") .Attr("serialized_pipeline: string") - .Attr("batch_size: int = 128") + .Attr("batch_size: int = -1") .Attr("height: int = 0") .Attr("width: int = 0") - .Attr("num_threads: int = 2") - .Attr("device_id: int = 0") + .Attr("num_threads: int = -1") + .Attr("device_id: int = -1") .Output("batch: float") .Output("label: float") .SetShapeFn([](tf::shape_inference::InferenceContext* c) { diff --git a/dali/test/python/test_pipeline.py b/dali/test/python/test_pipeline.py index b5918a874ff..59e04ddc456 100644 --- a/dali/test/python/test_pipeline.py +++ b/dali/test/python/test_pipeline.py @@ -217,6 +217,50 @@ def iter_setup(self): img_chw = img_chw_test assert(np.sum(np.abs(img_chw - img_chw_test)) == 0) +def test_seed_serialize(): + batch_size = 64 + class HybridPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id): + super(HybridPipe, self).__init__(batch_size, + num_threads, + device_id, + seed = 12) + self.input = ops.CaffeReader(path = caffe_db_folder, random_shuffle = True) + self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB) + self.cmnp = ops.CropMirrorNormalize(device = "gpu", + output_dtype = types.FLOAT, + crop = (224, 224), + image_type = types.RGB, + mean = [128., 128., 128.], + std = [1., 1., 1.]) + self.coin = ops.CoinFlip() + self.uniform = ops.Uniform(range = (0.0,1.0)) + self.iter = 0 + + def define_graph(self): + self.jpegs, self.labels = self.input() + images = self.decode(self.jpegs) + mirror = self.coin() + output = self.cmnp(images, mirror = mirror, crop_pos_x = self.uniform(), crop_pos_y = self.uniform()) + return (output, self.labels) + + def iter_setup(self): + pass + n = 30 + orig_pipe = HybridPipe(batch_size=batch_size, + num_threads=2, + device_id = 0) + s = 
orig_pipe.serialize() + for i in range(50): + pipe = Pipeline() + pipe.deserialize_and_build(s) + pipe_out = pipe.run() + pipe_out_cpu = pipe_out[0].asCPU() + img_chw_test = pipe_out_cpu.at(n) + if i == 0: + img_chw = img_chw_test + assert(np.sum(np.abs(img_chw - img_chw_test)) == 0) + def test_rotate(): class HybridPipe(Pipeline): def __init__(self, batch_size, num_threads, device_id):