forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#48 from Superjomn/fea/make-lowered_fu…
…nc-ir fea/make lowered func ir
- Loading branch information
Showing
18 changed files
with
261 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ | |
cmake-build* | ||
build* | ||
.idea* | ||
*.html |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "cinn/ir/lowered_func.h" | ||
|
||
#include "cinn/common/common.h" | ||
|
||
namespace cinn { | ||
namespace ir { | ||
|
||
const _LoweredFunc_* LoweredFunc::operator->() const { return As<_LoweredFunc_>(); } | ||
_LoweredFunc_* LoweredFunc::operator->() { return As<_LoweredFunc_>(); } | ||
|
||
LoweredFunc _LoweredFunc_::Make(const std::string& name, const std::vector<Argument>& args, const Expr& body) { | ||
auto* n = make_shared<_LoweredFunc_>(); | ||
n->name = name; | ||
n->args = args; | ||
n->body = body; | ||
return LoweredFunc(n); | ||
} | ||
|
||
LoweredFunc _LoweredFunc_::Make(const std::string& name, | ||
const std::vector<Argument>& args, | ||
const std::vector<Expr>& body) { | ||
CHECK_EQ(body.size(), 1); | ||
return Make(name, args, body.front()); | ||
} | ||
|
||
std::vector<Expr*> _LoweredFunc_::expr_fields() { return {&body}; } | ||
std::vector<const Expr*> _LoweredFunc_::expr_fields() const { return {&body}; } | ||
|
||
} // namespace ir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#pragma once | ||
#include "cinn/ir/buffer.h" | ||
#include "cinn/ir/node.h" | ||
|
||
namespace cinn { | ||
namespace ir { | ||
|
||
class _LoweredFunc_; | ||
|
||
/** | ||
* A struct representing an argument to a lowered function. Used for specifying the function signature of generated | ||
* code. | ||
*/ | ||
struct Argument { | ||
//! The name of the argument. | ||
std::string name; | ||
|
||
enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar}; | ||
|
||
//! Number of the dimensions of buffer. | ||
uint32_t ndims{0}; | ||
|
||
//! The type of the buffer or scalar. | ||
Type type; | ||
|
||
bool is_buffer() const { return kind == Kind::kBuffer; } | ||
bool is_scalar() const { return kind == Kind::kScalar; } | ||
|
||
Argument() {} | ||
Argument(const std::string& name, Kind kind, const Type& type, int ndims) | ||
: name(name), kind(kind), type(type), ndims(ndims) {} | ||
|
||
explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {} | ||
}; | ||
|
||
//! Wrapper for _LoweredFunc_ | ||
class LoweredFunc : public IrNodeRef { | ||
public: | ||
LoweredFunc() = default; | ||
explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {} | ||
|
||
const _LoweredFunc_* operator->() const; | ||
_LoweredFunc_* operator->(); | ||
}; | ||
|
||
/** | ||
* Definition of a lowered function. Note that, it should be functional. | ||
*/ | ||
struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { | ||
//! The name of this function. | ||
std::string name; | ||
|
||
//! The Arguments used in the body of the function. | ||
std::vector<Argument> args; | ||
|
||
//! Body of this function. | ||
Expr body; | ||
|
||
static LoweredFunc Make(const std::string& name, const std::vector<Argument>& args, const Expr& body); | ||
|
||
static LoweredFunc Make(const std::string& name, const std::vector<Argument>& args, const std::vector<Expr>& body); | ||
|
||
std::vector<Expr*> expr_fields() override; | ||
std::vector<const Expr*> expr_fields() const override; | ||
|
||
static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_; | ||
}; | ||
|
||
} // namespace ir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.