Skip to content

Commit

Permalink
Add support for the Module class. (PaddlePaddle#33)
Browse files Browse the repository at this point in the history
* Update return types of some methods in Instruction class and Function
class.

* Add methods for Module class, including ctor, ToProto, ToString, etc.
Add the unit test for Module.

* Prohibit copying and assignment of Module class.

* Add RawIterator support for instructions() and functions().

* Use the reference type for the range-for usage.

* Add more comments for Module and RawIterator.

* Fix some check errors and typos.

* Add PreconditionNotMet wrappers for error messages.

* Rename RawIterator by UnboxingIterator.

* Add `find_if` usage for func->instructions().
  • Loading branch information
wzzju authored Aug 23, 2021
1 parent 4510dff commit 988e2d8
Show file tree
Hide file tree
Showing 13 changed files with 539 additions and 120 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/compiler/piano/note/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ cc_test(note_opcode_test SRCS opcode_test.cc DEPS note_opcode)
proto_library(note_proto SRCS note.proto)
target_compile_options(note_proto PUBLIC "-Wno-extra")

cc_library(note_ir SRCS instruction.cc function.cc DEPS note_opcode note_proto piano_data_description)
cc_library(note_ir SRCS instruction.cc function.cc module.cc DEPS note_opcode note_proto piano_data_description)
cc_test(note_ir_test SRCS note_ir_test.cc DEPS note_ir)
19 changes: 17 additions & 2 deletions paddle/fluid/compiler/piano/note/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,27 @@ Function::Function(
instr->set_parent(this);
// set parameter(input) instructions field
if (instr->opcode() == OpCode::kParameter) {
PADDLE_ENFORCE_EQ(
instr->valid_parameter_index(), true,
platform::errors::PreconditionNotMet(
"The parameter instruction %s doesn't have a valid index.",
instr->name()));

param_instrs_.push_back(instr.get());
}
instr_index[instr_proto.id()] = instr.get();
inverted_index[instr.get()] = instr_proto.id();

auto instr_id = instr_proto.id();
PADDLE_ENFORCE_EQ(
instr_index.count(instr_id), 0,
platform::errors::PreconditionNotMet(
"The global id (%ld) of Instruction %s is the same as the previous "
"Instruction %s.",
instr_id, instr->name(), instr_index[instr_id]->name()));
instr_index[instr_id] = instr.get();
inverted_index[instr.get()] = instr_id;
instructions_.emplace_back(std::move(instr));
}

PADDLE_ENFORCE_EQ(
proto.return_id() >= 0 && instr_index.count(proto.return_id()), true,
platform::errors::PreconditionNotMet(
Expand Down
87 changes: 62 additions & 25 deletions paddle/fluid/compiler/piano/note/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "boost/range/iterator_range.hpp"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/note/type_traits.h"
#include "paddle/fluid/compiler/piano/shape.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace piano {
namespace note {

class Instruction;
// class Module;
class Module;

class Function {
public:
Expand All @@ -46,47 +49,82 @@ class Function {
const std::string &name() const { return name_; }

// return instructions owned by this function
std::vector<Instruction *> instructions() const {
std::vector<Instruction *> instrs;
instrs.reserve(instructions_.size());
std::transform(
instructions_.cbegin(), instructions_.cend(),
std::back_inserter(instrs),
[](const std::unique_ptr<Instruction> &instr) { return instr.get(); });
return instrs;
// for(Instruction &instr : function->instructions()){...}
auto instructions() const {
using IteratorT = decltype(instructions_.cbegin());
return boost::make_iterator_range(
UnboxingIterator<IteratorT>{instructions_.cbegin()},
UnboxingIterator<IteratorT>{instructions_.cend()});
}

const Instruction *instruction(std::int64_t idx) const {
return instructions_.at(idx).get();
// return an instruction included in this function by the given index
Instruction *instruction(std::int64_t idx) const {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(instructions_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, instructions_.size()));
PADDLE_ENFORCE_NOT_NULL(
instructions_[idx].get(),
platform::errors::PreconditionNotMet(
"The instruction %ld should not be null.", idx));
return instructions_[idx].get();
}

Instruction *mutable_instruction(std::int64_t idx) {
return instructions_.at(idx).get();
}

// return the function signature
// return the immutable function signature
const Signature &signature() const { return signature_; }

// return the mutable function signature
Signature *mutable_signature() { return &signature_; }

// return the globally unique id of this function
std::int64_t global_id() const { return global_id_; }

// return the returned instruction of this function
const Instruction *return_instr() const { return return_instr_; }
const Instruction &return_instr() const {
PADDLE_ENFORCE_NOT_NULL(return_instr_,
platform::errors::PreconditionNotMet(
"The return instruction should not be null."));
return *return_instr_;
}

// const Module *parent() const { return parent_; }
// return the immutable module which includes this function
const Module &parent() const {
PADDLE_ENFORCE_NOT_NULL(parent_, platform::errors::PreconditionNotMet(
"The parent_(Module) of this function "
"is null, please set it first."));
return *parent_;
}

// Module *mutable_parent() { return parent_; }
// return the mutable module which includes this function
Module *mutable_parent() {
PADDLE_ENFORCE_NOT_NULL(parent_, platform::errors::PreconditionNotMet(
"The parent_(Module) of this function "
"is null, please set it first."));
return parent_;
}

// void set_parent(Module *module) { parent_ = module; }
// set the module in which this function resides
void set_parent(Module *mod) { parent_ = mod; }

const std::vector<Instruction *> &param_instrs() const {
return param_instrs_;
}

// return parameter instructions of this function
const Instruction *param_instr(std::int64_t idx) const {
return param_instrs_.at(idx);
const Instruction &param_instr(std::int64_t idx) const {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(param_instrs_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, param_instrs_.size()));
PADDLE_ENFORCE_NOT_NULL(
param_instrs_[idx],
platform::errors::PreconditionNotMet(
"The parameter instruction %ld should not be null.", idx));
return *param_instrs_[idx];
}

// return the parameter(input) number of this function
Expand All @@ -104,9 +142,8 @@ class Function {
// the returned instruction of this function
Instruction *return_instr_;

// TODO(wzzju): Add Module class.
// the module where this function is contained
// Module *parent_{nullptr};
Module *parent_{nullptr};

// parameter instructions of this function,
// which denote input parameters
Expand Down
21 changes: 8 additions & 13 deletions paddle/fluid/compiler/piano/note/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Instruction::Instruction(

// add control dependency
for (auto id : proto.control_predecessor_ids()) {
PADDLE_ENFORCE_EQ(instr_index.at(id)->parent(), parent(),
PADDLE_ENFORCE_EQ(instr_index.at(id)->mutable_parent(), mutable_parent(),
platform::errors::PreconditionNotMet(
"The instruction and its dependent instruction are "
"not in the same function."));
Expand All @@ -68,16 +68,9 @@ Instruction::Instruction(
}
}

// set parameter number
if (proto.has_parameter_number()) {
PADDLE_ENFORCE_EQ(proto.parameter_number(), operands_.size(),
platform::errors::PreconditionNotMet(
"The number of operands(%ld) is not equal to the "
"parameter_number(%zu) in proto.",
proto.parameter_number(), operands_.size()));
parameter_number_ = proto.parameter_number();
} else {
parameter_number_ = static_cast<std::int64_t>(operands_.size());
// set parameter index
if (proto.has_parameter_index()) {
parameter_index_ = proto.parameter_index();
}

// set attrs
Expand All @@ -94,7 +87,9 @@ InstructionProto Instruction::ToProto() const {
proto.set_name(name_);
proto.set_opcode(GetOpName(opcode_));
proto.set_id(global_id_);
proto.set_parameter_number(parameter_number_);
if (valid_parameter_index()) {
proto.set_parameter_index(parameter_index_);
}

// serialize shape info
*proto.mutable_shape() = shape_.ToProto();
Expand Down Expand Up @@ -166,7 +161,7 @@ std::string Instruction::ToString() const {
string::join_strings(attr_strs, ", ").c_str());
}

void Instruction::Accept(backends::NoteVisitorBase* visitor) {
void Instruction::Accept(backends::NoteVisitorBase* visitor) const {
switch (opcode_) {
#define HANDLE_VISIT(enum_id, op_name, ...) \
case OpCode::k##enum_id: \
Expand Down
63 changes: 49 additions & 14 deletions paddle/fluid/compiler/piano/note/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,48 +50,79 @@ class Instruction {

std::string ToString() const;

void Accept(backends::NoteVisitorBase *visitor);

void Accept(backends::NoteVisitorBase *visitor) const {
return const_cast<Instruction *>(this)->Accept(visitor);
}
void Accept(backends::NoteVisitorBase *visitor) const;

// return the name of this instruction
const std::string &name() const { return name_; }

// return the opcode of this instruction
OpCode opcode() const { return opcode_; }

// return the immutable result shape of this instruction
const Shape &shape() const { return shape_; }

// return the mutable result shape of this instruction
Shape *mutable_shape() { return &shape_; }

const Function *parent() const { return parent_; }
// return the immutable function which includes this instruction
const Function &parent() const {
PADDLE_ENFORCE_NOT_NULL(parent_,
platform::errors::PreconditionNotMet(
"The parent_(Function) of this instruction is "
"null, please set it first."));
return *parent_;
}

Function *mutable_parent() { return parent_; }
// return the mutable function which includes this instruction
Function *mutable_parent() {
PADDLE_ENFORCE_NOT_NULL(parent_,
platform::errors::PreconditionNotMet(
"The parent_(Function) of this instruction is "
"null, please set it first."));
return parent_;
}

// set the function in which this instruction resides
void set_parent(Function *func) { parent_ = func; }

// return the globally unique id of this instruction
std::int64_t global_id() const { return global_id_; }

// return instruction operands
const std::vector<Instruction *> &operands() const { return operands_; }

const Instruction *operand(std::int64_t idx) const {
return operands_.at(idx);
const Instruction &operand(std::int64_t idx) const {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(operands_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, operands_.size()));
PADDLE_ENFORCE_NOT_NULL(operands_[idx],
platform::errors::PreconditionNotMet(
"operand %ld should not be null.", idx));
return *operands_[idx];
}

Instruction *mutable_operand(std::int64_t idx) const {
Instruction *mutable_operand(std::int64_t idx) {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(operands_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, operands_.size()));
PADDLE_ENFORCE_NOT_NULL(operands_[idx],
platform::errors::PreconditionNotMet(
"operand %ld should not be null.", idx));
return operands_.at(idx);
return operands_[idx];
}

// return the control predecessors of this instruction
const std::vector<Instruction *> &ctrl_predecessors() const {
return ctrl_predecessors_;
}

// return the control successors of this instruction
const std::vector<Instruction *> &ctrl_successors() const {
return ctrl_successors_;
}
Expand All @@ -101,7 +132,11 @@ class Instruction {
return call_functions_;
}

std::int64_t parameter_number() const { return parameter_number_; }
// return the input index of this instruction
std::int64_t parameter_index() const { return parameter_index_; }

// only the Parameter instruction has a valid parameter index
bool valid_parameter_index() const { return parameter_index_ != -1; }

// return attributes of this instruction
const MapType &attrs() const { return attrs_; }
Expand Down Expand Up @@ -148,8 +183,8 @@ class Instruction {
std::vector<Instruction *> ctrl_successors_;
// functions called directly by this instruction
std::vector<Function *> call_functions_;
// the parameter number of this instruction
std::int64_t parameter_number_;
// the input index of this instruction
std::int64_t parameter_index_{-1};
// attributes belongs to this instruction
MapType attrs_;
// the function where this instruction is contained
Expand Down
Loading

0 comments on commit 988e2d8

Please sign in to comment.