diff --git a/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc b/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc index 6061a1b4a03c4..0ed019e6633cb 100644 --- a/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc +++ b/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc @@ -29,6 +29,8 @@ #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/ipu/ipu_backend.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/scope.h" // debug #include "paddle/fluid/framework/ir/pass_tester_helper.h" @@ -51,6 +53,12 @@ void IpuGraphBuilderPass::ApplyImpl(ir::Graph* graph) const { std::shared_ptr ipu_backend = ipu::IpuBackend::GetInstance(); + // For Paddle inference + if (graph->Has(kParamScopeAttr)) { + auto& scope = graph->Get(kParamScopeAttr); + ipu_backend->SetScope(&scope); + } + ipu_backend->Compile(graph, feed_list, fetch_list); VLOG(10) << "Post Graph: "; diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index c002c7a10cb7b..e6be794fad1ce 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -52,6 +52,7 @@ if(WIN32 AND WITH_GPU) cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API}) else() create_static_lib(paddle_inference ${fluid_modules} ${STATIC_INFERENCE_API}) + target_link_libraries(paddle_inference -Wl,--allow-multiple-definition popart_canonicalization_utils) endif() if(NOT APPLE) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index f3f6baf8ee3b9..d3d0a55b2db6d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -289,7 +289,9 @@ void CpuPassStrategy::EnableMkldnnBfloat16() { } IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) { - passes_.assign({"ipu_graph_builder_pass", // + passes_.assign({"forward_graph_extract_pass", + "popart_canonicalization_pass", + "ipu_graph_builder_pass", "ipu_runtime_replacer_pass"}); }