Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#48 from Superjomn/fea/make-lowered_fu…
Browse files Browse the repository at this point in the history
…nc-ir

fea/make lowered func ir
  • Loading branch information
Superjomn authored Feb 27, 2020
2 parents d9a1874 + 785554f commit 18a20ce
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 161 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
cmake-build*
build*
.idea*
*.html
18 changes: 10 additions & 8 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#include "cinn/backends/codegen_c.h"

#include "cinn/ir/lowered_func.h"

namespace cinn {
namespace backends {

CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {}

void CodeGenC::Compile(const lang::Module &module) {}
void CodeGenC::Compile(const lang::LoweredFunc &function) {
os() << "void " << function.name;
void CodeGenC::Compile(const ir::LoweredFunc &function) {
os() << "void " << function->name;

// output arguments
os() << "(";

auto print_arg = [&](const lang::Argument &arg) {
auto print_arg = [&](const ir::Argument &arg) {
if (arg.is_buffer()) {
os() << "struct cinn_buffer_t *";
} else if (arg.is_scalar()) {
Expand All @@ -22,20 +24,20 @@ void CodeGenC::Compile(const lang::LoweredFunc &function) {
os() << arg.name;
};

for (int i = 0; i < function.args.size() - 1; i++) {
print_arg(function.args[i]);
for (int i = 0; i < function->args.size() - 1; i++) {
print_arg(function->args[i]);
os() << ", ";
}
if (function.args.size() >= 1) {
print_arg(function.args.back());
if (function->args.size() >= 1) {
print_arg(function->args.back());
}

os() << ")";

DoIndent();
os() << "{\n";

Print(function.body);
Print(function->body);

DoIndent();
os() << "}";
Expand Down
3 changes: 2 additions & 1 deletion cinn/backends/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cinn/ir/function.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/lowered_func.h"
#include "cinn/lang/module.h"

namespace cinn {
Expand All @@ -24,7 +25,7 @@ class CodeGenC : public ir::IrPrinter {
void Compile(const lang::Module& module);

protected:
void Compile(const lang::LoweredFunc& function);
void Compile(const ir::LoweredFunc& function);
void Compile(const ir::Buffer& buffer);

std::string PrintType(Type type);
Expand Down
1 change: 1 addition & 0 deletions cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ set(srcs
ir_mutator.cc
function.cc
function_definition.cc
lowered_func.cc
ir_operators.cc
buffer.cc
function_base.cc
Expand Down
5 changes: 5 additions & 0 deletions cinn/ir/ir_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,10 @@ void IRMutator::Visit(const _Tensor_ *expr, Expr *op) {
}
}

void IRMutator::Visit(const _LoweredFunc_ *expr, Expr *op) {
auto *node = op->As<_LoweredFunc_>();
IRVisitorBase::Visit(&node->body, &node->body);
}

} // namespace ir
} // namespace cinn
28 changes: 28 additions & 0 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

#include <vector>

#include "cinn/ir/lowered_func.h"
#include "cinn/lang/module.h"
#include "cinn/lang/tensor.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace ir {
Expand Down Expand Up @@ -173,6 +176,27 @@ void IrPrinter::Visit(const _Tensor_ *x) {
}
os_ << ")";
}
void IrPrinter::Visit(const _LoweredFunc_ *f) {
os_ << "function " << f->name << " ";

std::vector<std::string> arg_names;
for (auto &arg : f->args) {
arg_names.push_back(arg.name);
}
os_ << "(" << utils::Join(arg_names, ", ");

DoIndent();
os_ << "{";

IncIndent();

Print(f->body);

DecIndent();

DoIndent();
os_ << "}";
}
std::ostream &operator<<(std::ostream &os, Expr a) {
std::stringstream ss;
IrPrinter printer(ss);
Expand All @@ -181,5 +205,9 @@ std::ostream &operator<<(std::ostream &os, Expr a) {
return os;
}

std::ostream &operator<<(std::ostream &os, const ir::LoweredFunc &f) {}

std::ostream &operator<<(std::ostream &os, const lang::Module &m);

} // namespace ir
} // namespace cinn
8 changes: 8 additions & 0 deletions cinn/ir/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
#include "cinn/ir/ir_visitor.h"

namespace cinn {

namespace lang {
class Module;
class LoweredFunc;
} // namespace lang

namespace ir {

struct IrPrinter : public IRVisitor {
Expand Down Expand Up @@ -72,6 +78,7 @@ struct IrPrinter : public IRVisitor {
void Visit(const _IterVar_ *x) override {}
void Visit(const _Buffer_ *x) override;
void Visit(const _Tensor_ *x) override;
void Visit(const _LoweredFunc_ *x) override;

private:
std::ostream &os_;
Expand All @@ -80,6 +87,7 @@ struct IrPrinter : public IRVisitor {
};

std::ostream &operator<<(std::ostream &os, Expr a);
std::ostream &operator<<(std::ostream &os, const lang::Module &m);

} // namespace ir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/ir/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "cinn/ir/buffer.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/lowered_func.h"
#include "cinn/lang/tensor.h"

namespace cinn {
Expand Down
30 changes: 30 additions & 0 deletions cinn/ir/lowered_func.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "cinn/ir/lowered_func.h"

#include "cinn/common/common.h"

namespace cinn {
namespace ir {

const _LoweredFunc_* LoweredFunc::operator->() const { return As<_LoweredFunc_>(); }
_LoweredFunc_* LoweredFunc::operator->() { return As<_LoweredFunc_>(); }

LoweredFunc _LoweredFunc_::Make(const std::string& name, const std::vector<Argument>& args, const Expr& body) {
auto* n = make_shared<_LoweredFunc_>();
n->name = name;
n->args = args;
n->body = body;
return LoweredFunc(n);
}

LoweredFunc _LoweredFunc_::Make(const std::string& name,
const std::vector<Argument>& args,
const std::vector<Expr>& body) {
CHECK_EQ(body.size(), 1);
return Make(name, args, body.front());
}

std::vector<Expr*> _LoweredFunc_::expr_fields() { return {&body}; }
std::vector<const Expr*> _LoweredFunc_::expr_fields() const { return {&body}; }

} // namespace ir
} // namespace cinn
70 changes: 70 additions & 0 deletions cinn/ir/lowered_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma once
#include "cinn/ir/buffer.h"
#include "cinn/ir/node.h"

namespace cinn {
namespace ir {

class _LoweredFunc_;

/**
* A struct representing an argument to a lowered function. Used for specifying the function signature of generated
* code.
*/
struct Argument {
//! The name of the argument.
std::string name;

enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar};

//! Number of the dimensions of buffer.
uint32_t ndims{0};

//! The type of the buffer or scalar.
Type type;

bool is_buffer() const { return kind == Kind::kBuffer; }
bool is_scalar() const { return kind == Kind::kScalar; }

Argument() {}
Argument(const std::string& name, Kind kind, const Type& type, int ndims)
: name(name), kind(kind), type(type), ndims(ndims) {}

explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {}
};

//! Wrapper for _LoweredFunc_
class LoweredFunc : public IrNodeRef {
public:
LoweredFunc() = default;
explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {}

const _LoweredFunc_* operator->() const;
_LoweredFunc_* operator->();
};

/**
* Definition of a lowered function. Note that, it should be functional.
*/
struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
//! The name of this function.
std::string name;

//! The Arguments used in the body of the function.
std::vector<Argument> args;

//! Body of this function.
Expr body;

static LoweredFunc Make(const std::string& name, const std::vector<Argument>& args, const Expr& body);

static LoweredFunc Make(const std::string& name, const std::vector<Argument>& args, const std::vector<Expr>& body);

std::vector<Expr*> expr_fields() override;
std::vector<const Expr*> expr_fields() const override;

static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_;
};

} // namespace ir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class IRVisitor;
macro__(_IterVar_) \
macro__(_Buffer_) \
macro__(_Tensor_) \
macro__(_LoweredFunc_) \

#define NODETY_FORALL(__m) \
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
Expand Down
8 changes: 4 additions & 4 deletions cinn/lang/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Expr LowerGroup(const poly::detail::Group& group, const std::map<std::string, Ex
return e;
}

std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args) {
std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args) {
// make sure the graph's start-points in the args.

auto stages = poly::GatherStagesInTensors(args);
Expand Down Expand Up @@ -106,12 +106,12 @@ std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor
optim::RemoveNestedBlock(&block);

// prepare arguments
std::vector<Argument> arguments;
std::vector<ir::Argument> arguments;
for (auto& arg : args) {
arguments.emplace_back(arg->name, Argument::Kind::kBuffer, arg->type(), arg->shape.size());
arguments.emplace_back(arg->name, ir::Argument::Kind::kBuffer, arg->type(), arg->shape.size());
}

return {LoweredFunc(name, arguments, block)};
return {ir::_LoweredFunc_::Make(name, arguments, block)};
}

} // namespace lang
Expand Down
2 changes: 1 addition & 1 deletion cinn/lang/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace cinn {
namespace lang {
using ir::Tensor;

std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args);
std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args);

} // namespace lang
} // namespace cinn
2 changes: 1 addition & 1 deletion cinn/lang/lower_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ TEST(lower, basic) {
}
}
)ROC";
TEST_SOUTPUT(lower_funcs.front().body, out);
TEST_SOUTPUT(lower_funcs.front()->body, out);
}

TEST(lower, more_complex) {
Expand Down
14 changes: 4 additions & 10 deletions cinn/lang/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct _Module_ : Object {
std::string name;
Target target;
std::vector<ir::Buffer> buffers;
std::vector<ir::PackedFunc> functions;
std::vector<ir::LoweredFunc> functions;
std::vector<Module> submodules;

const char *type_info() const override { return "_Module_"; }
Expand All @@ -30,23 +30,17 @@ const std::string &Module::name() const { return self()->name; }

const std::vector<ir::Buffer> &Module::buffers() const { return self()->buffers; }

const std::vector<ir::PackedFunc> &Module::functions() const { return self()->functions; }
const std::vector<ir::LoweredFunc> &Module::functions() const { return self()->functions; }

const std::vector<Module> &Module::submodules() const { return self()->submodules; }

void Module::Append(const ir::Buffer &buffer) { self()->buffers.push_back(buffer); }
void Module::Append(const Buffer &buffer) { self()->buffers.push_back(buffer.buffer()); }

void Module::Append(const ir::PackedFunc &function) { self()->functions.push_back(function); }
void Module::Append(const ir::LoweredFunc &function) { self()->functions.push_back(function); }

void Module::Append(const Module &module) { self()->submodules.push_back(module); }

void Module::Compile(const backends::Outputs &outputs) const {}

LoweredFunc::LoweredFunc(const std::string &name, const std::vector<Argument> &args, const std::vector<Expr> &body) {
this->name = name;
this->args = args;
this->body = ir::Block::Make(body);
}

} // namespace lang
} // namespace cinn
Loading

0 comments on commit 18a20ce

Please sign in to comment.