forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
note ir builder and common operand (PaddlePaddle#12)
* initial note ir builder and operand * update operand construct and add unit test * enhance shape to support note_builder and add unit test * update NoteBuilder Build and Append implement and add unit test * fix comment * remove comment code * update comment and servel name rule * fix ut
- Loading branch information
Showing
9 changed files
with
483 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/compiler/piano/note_builder.h" | ||
#include <utility> | ||
#include "paddle/fluid/compiler/piano/note/note.pb.h" | ||
#include "paddle/fluid/string/string_helper.h" | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
std::string NameConcatId(const std::string& name, int64_t id, | ||
char delim = '.') { | ||
std::vector<std::string> strs({name, std::to_string(id)}); | ||
return paddle::string::join_strings(strs, delim); | ||
} | ||
|
||
const Operand::ShapeType& Operand::Shape() { | ||
return Builder()->GetShape(*this); | ||
} | ||
|
||
Operand NoteBuilder::AppendInstruction(note::InstructionProto&& instr, | ||
note::OpCode opcode, | ||
const std::vector<Operand>& operands) { | ||
// check the precondition of sevel special instructions | ||
if (opcode == note::OpCode::kParameter) { | ||
PADDLE_ENFORCE_EQ( | ||
instr.has_parameter_number(), true, | ||
platform::errors::PreconditionNotMet( | ||
"Parameter instruction shoule fill parameter_number field to " | ||
"indicate which parameter to be retrieved.")); | ||
|
||
const auto& index = instr.parameter_number(); | ||
PADDLE_ENFORCE_EQ(parameter_numbers_.count(index), 0, | ||
platform::errors::AlreadyExists( | ||
"Parameter[%d] already registered", index)); | ||
parameter_numbers_.insert(index); | ||
} | ||
|
||
instr.set_id(GetNextId()); | ||
instr.set_opcode(GetOpName(opcode)); | ||
if (instr.name().empty()) { | ||
instr.set_name(instr.opcode()); | ||
} | ||
|
||
for (const auto& op : operands) { | ||
PADDLE_ENFORCE_NOT_NULL( | ||
op.Builder(), | ||
platform::errors::InvalidArgument( | ||
"Invalid Operand[%d] because its builder is nullptr", op.Id())); | ||
PADDLE_ENFORCE_EQ(op.Builder(), this, | ||
platform::errors::InvalidArgument( | ||
"Operand builder_[%s] not consistent with the " | ||
"one[%s] of this instruction", | ||
op.Builder()->name(), this->name())); | ||
instr.add_operand_ids(op.Id()); | ||
} | ||
|
||
id2index_[instr.id()] = instructions_.size(); | ||
instructions_.emplace_back(std::move(instr)); | ||
instruction_shapes_.emplace_back(Shape(instructions_.back().shape())); | ||
return {instructions_.back().id(), this}; | ||
} | ||
|
||
const Shape& NoteBuilder::GetShape(Operand op) const { | ||
PADDLE_ENFORCE_EQ(op.Builder(), this, | ||
platform::errors::InvalidArgument( | ||
"Operand[%d] not belongs to this builder", op.Id())); | ||
PADDLE_ENFORCE_GT(id2index_.count(op.Id()), 0, | ||
platform::errors::NotFound( | ||
"Not found Operand[%d] on this builder", op.Id())); | ||
return instruction_shapes_[id2index_.at(op.Id())]; | ||
} | ||
|
||
Signature NoteBuilder::BuildSignature() const { | ||
Signature signature; | ||
// by default, the last instruction is root | ||
*signature.mutable_result() = Shape(instructions_.back().shape()); | ||
|
||
signature.mutable_parameters()->resize(parameter_numbers_.size()); | ||
signature.mutable_parameter_names()->resize(parameter_numbers_.size()); | ||
for (const auto& instr : instructions_) { | ||
static const auto parameter_opcode_name = | ||
note::GetOpName(note::OpCode::kParameter); | ||
if (instr.opcode() == parameter_opcode_name) { | ||
const auto& index = instr.parameter_number(); | ||
// this enforce will ensure the retrieved indexes of kParameter | ||
// are continuous from 0 to the size; | ||
PADDLE_ENFORCE_EQ( | ||
index >= 0 && index < parameter_numbers_.size(), true, | ||
platform::errors::OutOfRange("parameter number not in range[0, %lld]", | ||
parameter_numbers_.size())); | ||
|
||
signature.mutable_parameters()->at(index) = Shape(instr.shape()); | ||
signature.mutable_parameter_names()->at(index) = | ||
NameConcatId(instr.name(), instr.id()); | ||
} | ||
} | ||
|
||
return signature; | ||
} | ||
|
||
note::ModuleProto NoteBuilder::Build() { | ||
PADDLE_ENFORCE_NE(instructions_.empty(), true, | ||
platform::errors::PreconditionNotMet( | ||
"Can not build note::ModuleProto without instruction")); | ||
|
||
note::FunctionProto entry_function; | ||
entry_function.set_id(GetNextId()); | ||
entry_function.set_name(NameConcatId(this->name(), entry_function.id())); | ||
// by default, the last instruction is root | ||
entry_function.set_return_id(instructions_.back().id()); | ||
*entry_function.mutable_signature() = BuildSignature().ToProto(); | ||
for (auto& instruction : instructions_) { | ||
instruction.set_name(NameConcatId(instruction.name(), instruction.id())); | ||
// after building done all data will be cleared, | ||
// so just take the origin instruction here. | ||
entry_function.add_instructions()->Swap(&instruction); | ||
} | ||
|
||
note::ModuleProto note_module; | ||
note_module.set_id(entry_function.id()); | ||
note_module.set_entry_function_id(entry_function.id()); | ||
note_module.set_name(entry_function.name()); | ||
note_module.set_entry_function_name(entry_function.name()); | ||
note_module.mutable_entry_function_signature()->CopyFrom( | ||
entry_function.signature()); | ||
// take the origin entry_function directly | ||
note_module.add_functions()->Swap(&entry_function); | ||
|
||
// Clear data held by this builder. | ||
this->instructions_.clear(); | ||
this->instruction_shapes_.clear(); | ||
this->id2index_.clear(); | ||
this->parameter_numbers_.clear(); | ||
|
||
return note_module; | ||
} | ||
|
||
} // namespace piano | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include <vector> | ||
#include "glog/logging.h" | ||
#include "paddle/fluid/compiler/piano/note/opcode.h" | ||
#include "paddle/fluid/compiler/piano/shape.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/errors.h" | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
class NoteBuilder; | ||
|
||
// A Operand is generally constructed by a NoteBuilder, as the returned value of | ||
// an Instruction, and can be used as an operand for succeeding instructions | ||
class Operand { | ||
public: | ||
using ShapeType = paddle::piano::Shape; | ||
|
||
Operand() : instr_id_(-1), builder_(nullptr) {} | ||
|
||
// whether this operand valid | ||
bool Valid() const { return instr_id_ >= 0 && builder_ != nullptr; } | ||
|
||
// the builder that contructs this operand | ||
NoteBuilder* Builder() const { | ||
PADDLE_ENFORCE_NOT_NULL( | ||
builder_, platform::errors::InvalidArgument("Builder is nullptr")); | ||
return builder_; | ||
} | ||
|
||
// shape of this operand | ||
const ShapeType& Shape(); | ||
|
||
private: | ||
// declare the folloing methods as private and only can be used from friend | ||
// class, to prevent from illegal usage in complie-time | ||
explicit Operand(NoteBuilder* builder) : instr_id_(-1), builder_(builder) {} | ||
Operand(int64_t id, NoteBuilder* builder) | ||
: instr_id_(id), builder_(builder) {} | ||
|
||
int64_t Id() const { return instr_id_; } | ||
|
||
friend class NoteBuilder; | ||
|
||
private: | ||
// the unique id that denotes which instruction generate this value | ||
int64_t instr_id_; | ||
// the builder that holds the instruction | ||
NoteBuilder* builder_; | ||
}; | ||
|
||
// A NoteBuilder keeps a list of instructions within the same Note Module, | ||
// and user can append new instructions with one or more operands which come | ||
// from instructions enqueued | ||
// | ||
// This is used as a convenient interface for building up the initial Note | ||
// Module. | ||
class NoteBuilder { | ||
public: | ||
explicit NoteBuilder(const std::string& name) : name_(name) {} | ||
|
||
// Append an new instruction | ||
Operand AppendInstruction(note::InstructionProto&& instr, note::OpCode opcode, | ||
const std::vector<Operand>& operands); | ||
|
||
// Returns the shape of the given operand. | ||
const Shape& GetShape(Operand op) const; | ||
|
||
// Build the init note::ModuleProto with an entry function | ||
// which includes all instructions | ||
note::ModuleProto Build(); | ||
|
||
// name of this builder | ||
const std::string& name() { return name_; } | ||
|
||
private: | ||
// Generate the next sequential id | ||
int64_t GetNextId() { return ++next_id_; } | ||
|
||
// Build the signature of entry function | ||
Signature BuildSignature() const; | ||
|
||
private: | ||
// Name to use for the built note::ModuleProto | ||
std::string name_; | ||
|
||
// The next sequential ID for every instruction contained within this builer. | ||
int64_t next_id_ = 0; | ||
|
||
// The instructions list | ||
std::vector<note::InstructionProto> instructions_; | ||
|
||
// The shape list of appended instructions | ||
std::vector<Shape> instruction_shapes_; | ||
|
||
// A map from ID to the index in the instructions_ vector where the | ||
// instruction resides in. | ||
std::unordered_map<int64_t, int64_t> id2index_; | ||
|
||
// The unique parameter numbers. | ||
std::unordered_set<int64_t> parameter_numbers_; | ||
}; | ||
|
||
} // namespace piano | ||
} // namespace paddle |
Oops, something went wrong.