Skip to content

Commit

Permalink
Paddle inference support lowerWeight andpopart_canonicalization (Padd…
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozhixin authored Aug 9, 2021
1 parent f1715ca commit 8db9454
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/scope.h"

// debug
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
Expand All @@ -51,6 +53,12 @@ void IpuGraphBuilderPass::ApplyImpl(ir::Graph* graph) const {

std::shared_ptr<ipu::IpuBackend> ipu_backend = ipu::IpuBackend::GetInstance();

// For Paddle inference
if (graph->Has(kParamScopeAttr)) {
auto& scope = graph->Get<Scope>(kParamScopeAttr);
ipu_backend->SetScope(&scope);
}

ipu_backend->Compile(graph, feed_list, fetch_list);

VLOG(10) << "Post Graph: ";
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ if(WIN32 AND WITH_GPU)
cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API})
else()
create_static_lib(paddle_inference ${fluid_modules} ${STATIC_INFERENCE_API})
target_link_libraries(paddle_inference -Wl,--allow-multiple-definition popart_canonicalization_utils)
endif()

if(NOT APPLE)
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ void CpuPassStrategy::EnableMkldnnBfloat16() {
}

IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) {
passes_.assign({"ipu_graph_builder_pass", //
passes_.assign({"forward_graph_extract_pass",
"popart_canonicalization_pass",
"ipu_graph_builder_pass",
"ipu_runtime_replacer_pass"});
}

Expand Down

0 comments on commit 8db9454

Please sign in to comment.