From 7520f2d8525205fa713e4937cd1504cf4a26c0d1 Mon Sep 17 00:00:00 2001 From: superjomn Date: Mon, 7 May 2018 17:18:37 +0800 Subject: [PATCH 01/12] init tensorrt_engine_op --- .../fluid/inference/tensorrt/CMakeLists.txt | 3 +- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/tensorrt_engine_op.cc | 1 + paddle/fluid/operators/tensorrt_engine_op.h | 36 +++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/tensorrt_engine_op.cc create mode 100644 paddle/fluid/operators/tensorrt_engine_op.h diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index c8b656394b403c..3cec55ceb6af2b 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,5 +1,6 @@ +nv_library(tensorrt_engine SRC engine.cc DEPS dynload_cuda) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) -nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) +nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS tensorrt_engine) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc) add_subdirectory(convert) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 256aded8ca234a..5ffb59b8335910 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -208,6 +208,7 @@ op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) +op_library(tensorrt_engine_op DEPS tensorrt_engine) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(print_op DEPS lod_tensor) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc new file mode 100644 index 00000000000000..c297fbd1124838 --- /dev/null +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -0,0 +1 @@ +#include "paddle/fluid/operators/tensorrt_engine_op.h" diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h new file mode 100644 index 00000000000000..cd6334f1b0e66e --- /dev/null +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace inference { +namespace analysis { + +template +class TensorRTEngineKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override {} +}; + +class TensorRTEngineOp : public framework::OperatorWithKernel { + protected: + void Build() {} +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle From 783da840439f89eab72e2c88a86c9e436359b246 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 8 May 2018 13:47:17 +0800 Subject: [PATCH 02/12] init operator --- paddle/fluid/framework/operator.h | 8 +++++- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/tensorrt_engine_op.h | 25 +++++++++++-------- .../operators/tensorrt_engine_op_test.cc | 21 ++++++++++++++++ 4 files changed, 44 insertions(+), 11 deletions(-) create mode 100644 paddle/fluid/operators/tensorrt_engine_op_test.cc diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d373c48b1a75c5..f0d5bc5e5da7e8 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -79,6 +79,10 @@ class OperatorBase { virtual ~OperatorBase() {} + // Do some preparation in construction phase. It leaves empty by default, free + // to overload it and it will be called in the construction function. + void Prepare(){}; + /// Executor will call this interface function to Run an op. // The implementation should be written at RunImpl void Run(const Scope& scope, const platform::Place& place); @@ -163,7 +167,9 @@ class OperatorBase { const ::paddle::framework::VariableNameMap& inputs, \ const ::paddle::framework::VariableNameMap& outputs, \ const paddle::framework::AttributeMap& attrs) \ - : parent_cls(type, inputs, outputs, attrs) {} + : parent_cls(type, inputs, outputs, attrs) { \ + Prepare(); \ + } class NOP : public OperatorBase { public: diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 5ffb59b8335910..3d277cfd0016d7 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -280,3 +280,4 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) +nv_test(tensorrt_engine_op_test SRCS tensorrt_engine_op_test.cc DEPS tensorrt_engine_op tensorrt_engine) diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index cd6334f1b0e66e..83c9bc9981a4e8 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -17,20 +17,25 @@ #include "paddle/fluid/framework/operator.h" namespace paddle { -namespace inference { -namespace analysis { +namespace operators { + +class TensorRTEngineOp : public framework::OperatorWithKernel { + protected: + // Build the engine. + void Prepare() { + // Call converter, input the BlockDesc and build the network. + } +}; template class TensorRTEngineKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &context) const override {} -}; - -class TensorRTEngineOp : public framework::OperatorWithKernel { - protected: - void Build() {} + void Compute(const framework::ExecutionContext &context) const override { + // Convert input tensor from fluid to engine. + // Execute the engine. + // Convert output tensor from engine to fluid. + } }; -} // namespace analysis -} // namespace inference +} // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc new file mode 100644 index 00000000000000..715931ea04ccd5 --- /dev/null +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -0,0 +1,21 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/tensorrt_engine_op.h" +#include + +namespace paddle { + +namespace operators {} // namespace operators +} // namespace paddle From 70eeab0c1bd4cee52bbb5dfa89140c54e9636e7f Mon Sep 17 00:00:00 2001 From: superjomn Date: Sun, 13 May 2018 16:50:25 +0800 Subject: [PATCH 03/12] add all --- paddle/fluid/operators/tensorrt_engine_op.cc | 31 ++++++++++++++++++++ paddle/fluid/operators/tensorrt_engine_op.h | 6 ++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index c297fbd1124838..c6c12855b0c7fa 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -1 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #include "paddle/fluid/operators/tensorrt_engine_op.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" + +namespace paddle { +namespace operators { + +void paddle::operators::TensorRTEngineOp::Prepare() { + // Get the ProgramDesc and pass to convert. + const auto& block = Attr("subgraph"); + auto max_batch = Attr("max_batch"); + auto max_workspace = Attr("max_workspace"); + inference::tensorrt::TensorRTEngine engine(max_batch, max_workspace, nullptr); + inference::tensorrt::OpConverter::Global().ConvertBlock(block, &engine); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 83c9bc9981a4e8..10c6b6062360e3 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -20,11 +20,11 @@ namespace paddle { namespace operators { class TensorRTEngineOp : public framework::OperatorWithKernel { + public: + TensorRTEngineOp() = default; protected: // Build the engine. - void Prepare() { - // Call converter, input the BlockDesc and build the network. - } + void Prepare(); }; template From a9a5c8d92c0bcbbc1420758597213a9bd04d6efd Mon Sep 17 00:00:00 2001 From: superjomn Date: Sun, 13 May 2018 20:28:54 +0800 Subject: [PATCH 04/12] init --- paddle/fluid/operators/tensorrt_engine_op.cc | 47 +++++++++++++++++--- paddle/fluid/operators/tensorrt_engine_op.h | 34 +++++++++++--- 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index c6c12855b0c7fa..b75bf2d31b7728 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -13,20 +13,53 @@ limitations under the License. */ #include "paddle/fluid/operators/tensorrt_engine_op.h" + +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { namespace operators { -void paddle::operators::TensorRTEngineOp::Prepare() { +template +void paddle::operators::TensorRTEngineKernel::Prepare( + const framework::ExecutionContext &context) const { // Get the ProgramDesc and pass to convert. - const auto& block = Attr("subgraph"); - auto max_batch = Attr("max_batch"); - auto max_workspace = Attr("max_workspace"); - inference::tensorrt::TensorRTEngine engine(max_batch, max_workspace, nullptr); - inference::tensorrt::OpConverter::Global().ConvertBlock(block, &engine); + const auto &block = context.Attr("subgraph"); + max_batch_ = context.Attr("max_batch"); + auto max_workspace = context.Attr("max_workspace"); + engine_.reset(new inference::tensorrt::TensorRTEngine( + max_batch_, max_workspace, nullptr)); + inference::tensorrt::OpConverter::Global().ConvertBlock(block, engine_.get()); + engine_->FreezeNetwork(); } +class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TensorRTEngineOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Xs", "A list of inputs.").AsDuplicable(); + AddOutput("Ys", "A list of outputs").AsDuplicable(); + AddAttr("subgraph", "the subgraph"); + } +}; + +class TensorRTEngineInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; + } // namespace operators } // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, + ops::TensorRTEngineOpMaker, TensorRTEngineOpMaker); + +REGISTER_OP_CPU_KERNEL( + tensorrt_engine, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel); diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 10c6b6062360e3..c326386399659a 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { namespace operators { @@ -22,19 +23,42 @@ namespace operators { class TensorRTEngineOp : public framework::OperatorWithKernel { public: TensorRTEngineOp() = default; - protected: - // Build the engine. - void Prepare(); }; template class TensorRTEngineKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &context) const override { + void Compute(const framework::ExecutionContext& context) const override { + if (!engine_) { + Prepare(context); + } + auto& inputs = context.Inputs("Xs"); + PADDLE_ENFORCE(!inputs.empty(), "should pass more than one inputs"); + auto* var0 = context.Input(inputs.front()); + PADDLE_ENFORCE_NOT_NULL(var0); + auto* tensor0 = var0->GetMutable(); + const batch_size = tensor0->dims()[0]; + // Convert input tensor from fluid to engine. + for (const auto& x : context.Inputs("Xs")) { + // convert input and copy to TRT engine's buffer + } // Execute the engine. - // Convert output tensor from engine to fluid. + PADDLE_ENFORCE_GT(max_batch_, 0); + engine_->Execute(max_batch_); + // Convert output tensor from engine to fluid + for (const auto& y : context.Outputs("Ys")) { + // convert output and copy to fluid. + } } + + protected: + // Build the engine. + void Prepare(const framework::ExecutionContext& context) const; + + private: + mutable std::unique_ptr engine_; + mutable int max_batch_{0}; }; } // namespace operators From cfa184cf756a65e584fb25eaa0ac58fc61c53e4e Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 29 May 2018 15:40:02 +0800 Subject: [PATCH 05/12] add tensorrt op --- paddle/fluid/inference/tensorrt/engine.cc | 16 ++++++- paddle/fluid/inference/tensorrt/engine.h | 5 +- paddle/fluid/operators/tensorrt_engine_op.cc | 10 ++-- paddle/fluid/operators/tensorrt_engine_op.h | 50 ++++++++++++++------ 4 files changed, 58 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index fb27c8394c1f94..f387a231785a2e 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -131,6 +131,20 @@ void* TensorRTEngine::GetOutputInGPU(const std::string& name) { return buffer(name).buffer; } +void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst, + size_t max_size) { + // determine data size + auto it = buffer_sizes_.find(name); + PADDLE_ENFORCE(it != buffer_sizes_.end()); + PADDLE_ENFORCE_GT(it->second, 0); + PADDLE_ENFORCE_GE(max_size, it->second); + auto& buf = buffer(name); + PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before"); + PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second, + cudaMemcpyDeviceToDevice, *stream_), + 0); +} + void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst, size_t max_size) { // determine data size @@ -152,7 +166,7 @@ Buffer& TensorRTEngine::buffer(const std::string& name) { return buffers_[slot_offset]; } -void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data, +void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data, size_t size) { auto& buf = buffer(name); PADDLE_ENFORCE_NOT_NULL(buf.buffer); diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index b8298c6059e864..fdd100670d59ee 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -92,13 +92,14 @@ class TensorRTEngine : public EngineBase { cudaStream_t* stream() { return stream_; } // Fill an input from CPU memory with name and size. - void SetInputFromCPU(const std::string& name, void* data, size_t size); + void SetInputFromCPU(const std::string& name, const void* data, size_t size); // TODO(Superjomn) is this method necessary given that buffer(xxx) can be // accessed directly. Fill an input from GPU memory with name and size. - void SetInputFromGPU(const std::string& name, void* data, size_t size); + void SetInputFromGPU(const std::string& name, const void* data, size_t size); // Get an output called name, the output of tensorrt is in GPU, so this method // will just return the output's GPU memory address. void* GetOutputInGPU(const std::string& name); + void GetOutputInGPU(const std::string& name, void* dst, size_t max_size); // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU // to CPU. void GetOutputInCPU(const std::string& name, void* dst, size_t max_size); diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index a029656ae97a5b..7984be0fed62a1 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -30,17 +30,17 @@ void paddle::operators::TensorRTEngineKernel::Prepare( auto max_workspace = context.Attr("max_workspace"); engine_.reset(new inference::tensorrt::TensorRTEngine( max_batch_, max_workspace, nullptr)); - inference::Singleton::Global().ConvertBlock(block, engine_.get()); + inference::Singleton::Global().ConvertBlock( + block, engine_.get()); engine_->FreezeNetwork(); } class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { public: - TensorRTEngineOpMaker(framework::OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { + void Make() override { AddInput("Xs", "A list of inputs.").AsDuplicable(); AddOutput("Ys", "A list of outputs").AsDuplicable(); - AddAttr("subgraph", "the subgraph"); + AddAttr("subgraph", "the subgraph"); } }; @@ -56,7 +56,7 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference { namespace ops = paddle::operators; REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, - ops::TensorRTEngineOpMaker, TensorRTEngineOpMaker); + ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker); REGISTER_OP_CPU_KERNEL( tensorrt_engine, diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 4795e21cdda279..1445643511edff 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { @@ -22,7 +23,19 @@ namespace operators { class TensorRTEngineOp : public framework::OperatorWithKernel { public: - TensorRTEngineOp() = default; + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::OpKernelType kt = framework::OpKernelType( + framework::ToDataType( + ctx.Input("pre_ids")->type()), + platform::CPUPlace()); + return kt; + } }; template @@ -32,12 +45,13 @@ class TensorRTEngineKernel : public framework::OpKernel { if (!engine_) { Prepare(context); } - auto& inputs = context.Inputs("Xs"); - PADDLE_ENFORCE(!inputs.empty(), "should pass more than one inputs"); - auto* var0 = context.Input(inputs.front()); - PADDLE_ENFORCE_NOT_NULL(var0); - auto* tensor0 = var0->GetMutable(); - const batch_size = tensor0->dims()[0]; + auto input_names = context.op().Inputs("Xs"); + PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); + // Try to determine a batch_size + auto* tensor0 = context.Input(input_names.front()); + PADDLE_ENFORCE_NOT_NULL(tensor0); + int batch_size = tensor0->dims()[0]; + PADDLE_ENFORCE_LE(batch_size, max_batch_); // Convert input tensor from fluid to engine. for (const auto& x : context.Inputs("Xs")) { @@ -46,29 +60,35 @@ class TensorRTEngineKernel : public framework::OpKernel { PADDLE_ENFORCE_NOT_NULL(v, "no variable called %s", x); auto& t = v->Get(); if (platform::is_cpu_place(t.place())) { - engine_->SetInputFromCPU(x, static_cast(t.data()), t.memory_size()); + engine_->SetInputFromCPU(x, static_cast(t.data()), + t.memory_size()); } else { - engine_->SetInputFromGPU(x, static_cast(t.data()), t.memory_size()); + engine_->SetInputFromGPU(x, static_cast(t.data()), + t.memory_size()); } } // Execute the engine. - PADDLE_ENFORCE_GT(max_batch_, 0); - engine_->Execute(max_batch_); + PADDLE_ENFORCE_GT(batch_size, 0); + engine_->Execute(batch_size); // Convert output tensor from engine to fluid for (const auto& y : context.Outputs("Ys")) { // convert output and copy to fluid. nvinfer1::ITensor* trt_t = engine_->GetITensor(y); - auto* trt_v = engine_->GetOutputInGPU(y); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. - std::vector ddim(dims.d, dims.d+dims.nbDims); + std::vector ddim(dims.d, dims.d + dims.nbDims); - auto *fluid_v = context.scope().FindVar(y); + auto* fluid_v = context.scope().FindVar(y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); auto* fluid_t = fluid_v->GetMutable(); fluid_t->Resize(framework::make_ddim(ddim)); + auto size = inference::analysis::AccuDims(dims.d, dims.nbDims); if (platform::is_cpu_place(fluid_t->place())) { - engine_-> + engine_->GetOutputInCPU( + y, fluid_t->mutable_data(platform::CPUPlace()), size); + } else { + engine_->GetOutputInCPU( + y, fluid_t->mutable_data(platform::CUDAPlace()), size); } } } From cf6f13251f5fc362bef2e9ace3028baf6a3010a2 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 29 May 2018 15:41:54 +0800 Subject: [PATCH 06/12] restore operator.h --- paddle/fluid/framework/operator.h | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index b27d86efff285b..2f480e00c100d5 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -78,10 +78,6 @@ class OperatorBase { virtual ~OperatorBase() {} - // Do some preparation in construction phase. It leaves empty by default, free - // to overload it and it will be called in the construction function. - void Prepare(){}; - /// Executor will call this interface function to Run an op. // The implementation should be written at RunImpl void Run(const Scope& scope, const platform::Place& place); @@ -166,9 +162,7 @@ class OperatorBase { const ::paddle::framework::VariableNameMap& inputs, \ const ::paddle::framework::VariableNameMap& outputs, \ const paddle::framework::AttributeMap& attrs) \ - : parent_cls(type, inputs, outputs, attrs) { \ - Prepare(); \ - } + : parent_cls(type, inputs, outputs, attrs) {} class NOP : public OperatorBase { public: From c4e29a5bb66a6646936f4b466aa2471504d8eca4 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 29 May 2018 15:53:55 +0800 Subject: [PATCH 07/12] update --- paddle/fluid/inference/tensorrt/engine.h | 3 ++- paddle/fluid/operators/tensorrt_engine_op.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index fdd100670d59ee..d9d3163b66d4c4 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -97,8 +97,9 @@ class TensorRTEngine : public EngineBase { // accessed directly. Fill an input from GPU memory with name and size. void SetInputFromGPU(const std::string& name, const void* data, size_t size); // Get an output called name, the output of tensorrt is in GPU, so this method - // will just return the output's GPU memory address. + // Return the output's GPU memory address without copy. void* GetOutputInGPU(const std::string& name); + // Copy data into dst inside the GPU device. void GetOutputInGPU(const std::string& name, void* dst, size_t max_size); // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU // to CPU. diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 1445643511edff..cd156b225de5ff 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -87,7 +87,7 @@ class TensorRTEngineKernel : public framework::OpKernel { engine_->GetOutputInCPU( y, fluid_t->mutable_data(platform::CPUPlace()), size); } else { - engine_->GetOutputInCPU( + engine_->GetOutputInGPU( y, fluid_t->mutable_data(platform::CUDAPlace()), size); } } From 48d7b702bcffae85069cf806e0813124af6f4d07 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 29 May 2018 17:00:15 +0800 Subject: [PATCH 08/12] add if --- paddle/fluid/operators/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c4299191a03853..6e8f5aa52b8c5d 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -221,7 +221,9 @@ op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) -op_library(tensorrt_engine_op DEPS tensorrt_engine) +if (TENSORRT_FOUND) + op_library(tensorrt_engine_op DEPS tensorrt_engine) +endif() op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(print_op DEPS lod_tensor) From 3f8b9dbc940723bdec47f333602ec9f8cf041e3d Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 29 May 2018 18:57:27 +0800 Subject: [PATCH 09/12] update --- paddle/fluid/operators/CMakeLists.txt | 3 +-- paddle/fluid/operators/tensorrt_engine_op.cc | 1 - .../operators/tensorrt_engine_op_test.cc | 21 ------------------- 3 files changed, 1 insertion(+), 24 deletions(-) delete mode 100644 paddle/fluid/operators/tensorrt_engine_op_test.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6e8f5aa52b8c5d..e5c75cdb77fd9c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -221,7 +221,7 @@ op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) -if (TENSORRT_FOUND) +if (WITH_GPU && TENSORRT_FOUND) op_library(tensorrt_engine_op DEPS tensorrt_engine) endif() op_library(sum_op DEPS selected_rows_functor) @@ -300,4 +300,3 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) -nv_test(tensorrt_engine_op_test SRCS tensorrt_engine_op_test.cc DEPS tensorrt_engine_op tensorrt_engine) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 7984be0fed62a1..170e89ce9dae26 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include "paddle/fluid/operators/tensorrt_engine_op.h" - #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/utils/singleton.h" diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc deleted file mode 100644 index 715931ea04ccd5..00000000000000 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/operators/tensorrt_engine_op.h" -#include - -namespace paddle { - -namespace operators {} // namespace operators -} // namespace paddle From f57c56fbbbd48bd3a92fc86456490f38785beb5a Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 29 May 2018 19:40:06 +0800 Subject: [PATCH 10/12] Update CMakeLists.txt --- paddle/fluid/operators/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e5c75cdb77fd9c..c274e073102f47 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -221,7 +221,7 @@ op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) -if (WITH_GPU && TENSORRT_FOUND) +if (WITH_GPU AND TENSORRT_FOUND) op_library(tensorrt_engine_op DEPS tensorrt_engine) endif() op_library(sum_op DEPS selected_rows_functor) From e18190ae7d6689dcc931eb822680a3df42e3566c Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 30 May 2018 08:14:57 +0800 Subject: [PATCH 11/12] update --- paddle/fluid/inference/tensorrt/engine.cc | 10 ++++++++++ paddle/fluid/operators/tensorrt_engine_op.cc | 5 +++++ paddle/fluid/operators/tensorrt_engine_op.h | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index f387a231785a2e..a88236ae98e181 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -176,6 +176,16 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data, cudaMemcpyHostToDevice, *stream_)); } +void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data, + size_t size) { + auto& buf = buffer(name); + PADDLE_ENFORCE_NOT_NULL(buf.buffer); + PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small"); + PADDLE_ENFORCE(buf.device == DeviceType::GPU); + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size, + cudaMemcpyDeviceToDevice, *stream_)); +} + void TensorRTEngine::SetITensor(const std::string& name, nvinfer1::ITensor* tensor) { PADDLE_ENFORCE(tensor != nullptr); diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 170e89ce9dae26..83e768b4dc9c60 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -12,6 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +#ifdef PADDLE_WITH_CUDA + #include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" @@ -40,6 +42,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Xs", "A list of inputs.").AsDuplicable(); AddOutput("Ys", "A list of outputs").AsDuplicable(); AddAttr("subgraph", "the subgraph"); + AddComment("TensorRT engine operator."); } }; @@ -63,3 +66,5 @@ REGISTER_OP_CPU_KERNEL( ops::TensorRTEngineKernel, ops::TensorRTEngineKernel, ops::TensorRTEngineKernel); + +#endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index cd156b225de5ff..59cece29ebfe4f 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -14,6 +14,8 @@ #pragma once +#ifdef PADDLE_WITH_CUDA + #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/engine.h" @@ -104,3 +106,5 @@ class TensorRTEngineKernel : public framework::OpKernel { } // namespace operators } // namespace paddle + +#endif // PADDLE_WITH_CUDA From 99176a5fc93c89fae199e18b5781e4f85b265c9c Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 30 May 2018 17:22:49 +0800 Subject: [PATCH 12/12] change void* --- paddle/fluid/operators/tensorrt_engine_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 59cece29ebfe4f..fe273d386c529b 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -62,10 +62,10 @@ class TensorRTEngineKernel : public framework::OpKernel { PADDLE_ENFORCE_NOT_NULL(v, "no variable called %s", x); auto& t = v->Get(); if (platform::is_cpu_place(t.place())) { - engine_->SetInputFromCPU(x, static_cast(t.data()), + engine_->SetInputFromCPU(x, static_cast(t.data()), t.memory_size()); } else { - engine_->SetInputFromGPU(x, static_cast(t.data()), + engine_->SetInputFromGPU(x, static_cast(t.data()), t.memory_size()); } }