Skip to content

Commit

Permalink
enable tensor.mutable_data on compile-time with an option (PaddlePadd…
Browse files Browse the repository at this point in the history
…le#463)

* enable instantiating tensor with a option

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: wangone <2279939962@qq.com>
  • Loading branch information
3 people authored Oct 9, 2021
1 parent c3b823a commit 9f165a1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
30 changes: 24 additions & 6 deletions cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,9 @@ void GraphCompiler::ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_
for (auto& shape_dim : j.buffer_arg()->shape) {
VLOG(3) << shape_dim << ",";
CHECK(shape_dim.is_constant());
shape.push_back((int)(shape_dim.get_constant()));
shape.push_back(static_cast<int>(shape_dim.get_constant()));
}
tensor->Resize(Shape{shape});
tensor->mutable_data<float>(target_);
}
}
function2input_args_[i->name] = input_args;
Expand All @@ -336,6 +335,15 @@ void GraphCompiler::ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_
}

std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
GraphCompiler::CompileOptions options;
options.attached_code = code;
options.with_instantiate_variables = true;

auto&& result = Build(options);
return std::move(result.runtime_program);
}

GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::CompileOptions& options) {
auto topo_order = graph_->topological_order();
auto& nodes = std::get<0>(topo_order);
auto& edges = std::get<1>(topo_order);
Expand Down Expand Up @@ -378,9 +386,20 @@ std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
VLOG(3) << "[X86] C Code is:\n" << out;
}

compiler_->Build(build_module, code);
compiler_->Build(build_module, options.attached_code);
if (options.with_instantiate_variables) {
VLOG(3) << "Initantiate all variables on compile-time";
// All variables reside in scope_, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) {
auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var);
tensor->mutable_data<float>(target_);
}
}

return std::unique_ptr<Program>(new Program(scope_, BuildInstructions()));
GraphCompiler::CompilationResult result;
result.runtime_program.reset(new Program(scope_, BuildInstructions()));
return result;
}

std::vector<std::unique_ptr<Instruction>> GraphCompiler::BuildInstructions() {
Expand Down Expand Up @@ -537,7 +556,7 @@ std::vector<std::unique_ptr<Instruction>> GraphCompiler::BuildInstructions() {
int i = 1;
std::string new_op_func = op_func_name + "_" + std::to_string(i);
if (function2input_args_.count(new_op_func) != 0) {
CHECK(function2input_args_.count(op_func_name) > 0);
CHECK_GT(function2input_args_.count(op_func_name), 0);
instr->AddInArgs(function2input_args_[op_func_name]);
instr->AddOutArgs(function2output_args_[op_func_name]);
}
Expand Down Expand Up @@ -633,7 +652,6 @@ std::shared_ptr<Scope> BuildScope(Target target, const std::shared_ptr<Graph>& g
tensor->Resize(Shape{shape});
CHECK_EQ(dtype_dict.at(iter.first), Float(32))
<< "The dtype of node " << iter.first << " is not float! Other dtype is not implemented yet.";
tensor->mutable_data<float>(target);
}
return scope;
}
Expand Down
13 changes: 13 additions & 0 deletions cinn/hlir/framework/graph_compiler.h
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <map>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -102,6 +103,18 @@ class GraphCompiler final {
GraphCompiler(Target target, const std::shared_ptr<Scope>& scope, const std::shared_ptr<Graph>& graph)
: target_(std::move(target)), scope_(scope), graph_(graph), m_builder_(UniqName("module"), target) {}

struct CompilationResult {
std::unique_ptr<Program> runtime_program;
};

struct CompileOptions {
std::string attached_code = "";
bool with_instantiate_variables = false;
};

// Compile with a packing option and result, to be extended easily.
CompilationResult Build(const CompileOptions& options);

std::unique_ptr<Program> Build(const std::string& code = "");

std::string GenSourceCode();
Expand Down

0 comments on commit 9f165a1

Please sign in to comment.