Skip to content

Commit

Permalink
note ir builder and common operand (PaddlePaddle#12)
Browse files Browse the repository at this point in the history
* initial note ir builder and operand

* update operand construct and add unit test

* enhance shape to support note_builder and add unit test

* update NoteBuilder Build and Append implement and add unit test

* fix comment

* remove comment code

* update comment and servel name rule

* fix ut
  • Loading branch information
CtfGo authored Aug 17, 2021
1 parent f6c3b77 commit 9afbf9a
Show file tree
Hide file tree
Showing 9 changed files with 483 additions and 2 deletions.
7 changes: 5 additions & 2 deletions paddle/fluid/compiler/piano/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@ add_subdirectory(backends)
add_subdirectory(note)

cc_library(piano_data_description SRCS layout.cc shape.cc DEPS string_helper note_proto)
cc_test(layout_test SRCS layout_test.cc DEPS piano_data_description)
cc_test(shape_test SRCS shape_test.cc DEPS piano_data_description)
cc_test(piano_layout_test SRCS layout_test.cc DEPS piano_data_description)
cc_test(piano_shape_test SRCS shape_test.cc DEPS piano_data_description)

cc_library(note_builder SRCS note_builder.cc DEPS string_helper note_opcode piano_data_description)
cc_test(note_builder_test SRCS note_builder_test.cc DEPS note_builder)
8 changes: 8 additions & 0 deletions paddle/fluid/compiler/piano/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,13 @@ std::string Layout::ToString() const {

bool Layout::Valid() const { return !minor_to_major().empty(); }

bool Layout::operator==(const Layout& other) const {
if (minor_to_major() != other.minor_to_major()) {
return false;
}

return true;
}

} // namespace piano
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/compiler/piano/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class Layout {
// Return whether this layout is valid
bool Valid() const;

// Return whether this layout euqal to the other
bool operator==(const Layout& other) const;
bool operator!=(const Layout& other) const { return !(*this == other); }

// The following methods for accessing the data member of a Layout object
// stores.
//
Expand Down
152 changes: 152 additions & 0 deletions paddle/fluid/compiler/piano/note_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/note_builder.h"
#include <utility>
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace piano {

std::string NameConcatId(const std::string& name, int64_t id,
char delim = '.') {
std::vector<std::string> strs({name, std::to_string(id)});
return paddle::string::join_strings(strs, delim);
}

const Operand::ShapeType& Operand::Shape() {
return Builder()->GetShape(*this);
}

Operand NoteBuilder::AppendInstruction(note::InstructionProto&& instr,
note::OpCode opcode,
const std::vector<Operand>& operands) {
// check the precondition of sevel special instructions
if (opcode == note::OpCode::kParameter) {
PADDLE_ENFORCE_EQ(
instr.has_parameter_number(), true,
platform::errors::PreconditionNotMet(
"Parameter instruction shoule fill parameter_number field to "
"indicate which parameter to be retrieved."));

const auto& index = instr.parameter_number();
PADDLE_ENFORCE_EQ(parameter_numbers_.count(index), 0,
platform::errors::AlreadyExists(
"Parameter[%d] already registered", index));
parameter_numbers_.insert(index);
}

instr.set_id(GetNextId());
instr.set_opcode(GetOpName(opcode));
if (instr.name().empty()) {
instr.set_name(instr.opcode());
}

for (const auto& op : operands) {
PADDLE_ENFORCE_NOT_NULL(
op.Builder(),
platform::errors::InvalidArgument(
"Invalid Operand[%d] because its builder is nullptr", op.Id()));
PADDLE_ENFORCE_EQ(op.Builder(), this,
platform::errors::InvalidArgument(
"Operand builder_[%s] not consistent with the "
"one[%s] of this instruction",
op.Builder()->name(), this->name()));
instr.add_operand_ids(op.Id());
}

id2index_[instr.id()] = instructions_.size();
instructions_.emplace_back(std::move(instr));
instruction_shapes_.emplace_back(Shape(instructions_.back().shape()));
return {instructions_.back().id(), this};
}

const Shape& NoteBuilder::GetShape(Operand op) const {
PADDLE_ENFORCE_EQ(op.Builder(), this,
platform::errors::InvalidArgument(
"Operand[%d] not belongs to this builder", op.Id()));
PADDLE_ENFORCE_GT(id2index_.count(op.Id()), 0,
platform::errors::NotFound(
"Not found Operand[%d] on this builder", op.Id()));
return instruction_shapes_[id2index_.at(op.Id())];
}

Signature NoteBuilder::BuildSignature() const {
Signature signature;
// by default, the last instruction is root
*signature.mutable_result() = Shape(instructions_.back().shape());

signature.mutable_parameters()->resize(parameter_numbers_.size());
signature.mutable_parameter_names()->resize(parameter_numbers_.size());
for (const auto& instr : instructions_) {
static const auto parameter_opcode_name =
note::GetOpName(note::OpCode::kParameter);
if (instr.opcode() == parameter_opcode_name) {
const auto& index = instr.parameter_number();
// this enforce will ensure the retrieved indexes of kParameter
// are continuous from 0 to the size;
PADDLE_ENFORCE_EQ(
index >= 0 && index < parameter_numbers_.size(), true,
platform::errors::OutOfRange("parameter number not in range[0, %lld]",
parameter_numbers_.size()));

signature.mutable_parameters()->at(index) = Shape(instr.shape());
signature.mutable_parameter_names()->at(index) =
NameConcatId(instr.name(), instr.id());
}
}

return signature;
}

note::ModuleProto NoteBuilder::Build() {
PADDLE_ENFORCE_NE(instructions_.empty(), true,
platform::errors::PreconditionNotMet(
"Can not build note::ModuleProto without instruction"));

note::FunctionProto entry_function;
entry_function.set_id(GetNextId());
entry_function.set_name(NameConcatId(this->name(), entry_function.id()));
// by default, the last instruction is root
entry_function.set_return_id(instructions_.back().id());
*entry_function.mutable_signature() = BuildSignature().ToProto();
for (auto& instruction : instructions_) {
instruction.set_name(NameConcatId(instruction.name(), instruction.id()));
// after building done all data will be cleared,
// so just take the origin instruction here.
entry_function.add_instructions()->Swap(&instruction);
}

note::ModuleProto note_module;
note_module.set_id(entry_function.id());
note_module.set_entry_function_id(entry_function.id());
note_module.set_name(entry_function.name());
note_module.set_entry_function_name(entry_function.name());
note_module.mutable_entry_function_signature()->CopyFrom(
entry_function.signature());
// take the origin entry_function directly
note_module.add_functions()->Swap(&entry_function);

// Clear data held by this builder.
this->instructions_.clear();
this->instruction_shapes_.clear();
this->id2index_.clear();
this->parameter_numbers_.clear();

return note_module;
}

} // namespace piano
} // namespace paddle
124 changes: 124 additions & 0 deletions paddle/fluid/compiler/piano/note_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/compiler/piano/note/opcode.h"
#include "paddle/fluid/compiler/piano/shape.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace piano {

class NoteBuilder;

// A Operand is generally constructed by a NoteBuilder, as the returned value of
// an Instruction, and can be used as an operand for succeeding instructions
class Operand {
public:
using ShapeType = paddle::piano::Shape;

Operand() : instr_id_(-1), builder_(nullptr) {}

// whether this operand valid
bool Valid() const { return instr_id_ >= 0 && builder_ != nullptr; }

// the builder that contructs this operand
NoteBuilder* Builder() const {
PADDLE_ENFORCE_NOT_NULL(
builder_, platform::errors::InvalidArgument("Builder is nullptr"));
return builder_;
}

// shape of this operand
const ShapeType& Shape();

private:
// declare the folloing methods as private and only can be used from friend
// class, to prevent from illegal usage in complie-time
explicit Operand(NoteBuilder* builder) : instr_id_(-1), builder_(builder) {}
Operand(int64_t id, NoteBuilder* builder)
: instr_id_(id), builder_(builder) {}

int64_t Id() const { return instr_id_; }

friend class NoteBuilder;

private:
// the unique id that denotes which instruction generate this value
int64_t instr_id_;
// the builder that holds the instruction
NoteBuilder* builder_;
};

// A NoteBuilder keeps a list of instructions within the same Note Module,
// and user can append new instructions with one or more operands which come
// from instructions enqueued
//
// This is used as a convenient interface for building up the initial Note
// Module.
class NoteBuilder {
public:
explicit NoteBuilder(const std::string& name) : name_(name) {}

// Append an new instruction
Operand AppendInstruction(note::InstructionProto&& instr, note::OpCode opcode,
const std::vector<Operand>& operands);

// Returns the shape of the given operand.
const Shape& GetShape(Operand op) const;

// Build the init note::ModuleProto with an entry function
// which includes all instructions
note::ModuleProto Build();

// name of this builder
const std::string& name() { return name_; }

private:
// Generate the next sequential id
int64_t GetNextId() { return ++next_id_; }

// Build the signature of entry function
Signature BuildSignature() const;

private:
// Name to use for the built note::ModuleProto
std::string name_;

// The next sequential ID for every instruction contained within this builer.
int64_t next_id_ = 0;

// The instructions list
std::vector<note::InstructionProto> instructions_;

// The shape list of appended instructions
std::vector<Shape> instruction_shapes_;

// A map from ID to the index in the instructions_ vector where the
// instruction resides in.
std::unordered_map<int64_t, int64_t> id2index_;

// The unique parameter numbers.
std::unordered_set<int64_t> parameter_numbers_;
};

} // namespace piano
} // namespace paddle
Loading

0 comments on commit 9afbf9a

Please sign in to comment.