Skip to content

Commit

Permalink
Exposing Channel to be used as a Variable and integrating with Fluid (#…
Browse files Browse the repository at this point in the history
…8486)

* Adding set_capacity method support

* Adding Python for make_channel

* Updating notest_concurrency

* Write python for make_channel method

* Write python for make_channel method

* Fix make_channel and test

* Placeholder ops for channel send, recv and close

* Adding ToTypeIndex method to var_type.h

* Add var_type.h to channel:

* Added POD_Type to the method

* Add CHANNEL to executor

* Updated get and set DataType to accomodate Channels

* Updating get and set to incorporate channels

* Adding CHANNEL as supported VarType in protobuf

* Removing unecessary import

* Fixing VarDesc to adapt to Channel as VarType

* Add channel.h to executor

* Remove innclude from channel

* Updated var_type to support Channel as  var type

* Adding get_channel to pybind

* Added ChannelHolder

* Adding make_channel as an op

* Adding ChannelHolder in channel

* Fixing typo

* Commenting out operators in concurrency

* Removing totypeid right now since we don't need it.

* Reverting python changes

* Fixing typo in framework.py

* Modify comments for ReaderHolder
  • Loading branch information
kavyasrinet authored Feb 23, 2018
1 parent 88c22e9 commit 77ee8fb
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 3 deletions.
73 changes: 73 additions & 0 deletions paddle/fluid/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License. */
#pragma once

#include <stddef.h> // for size_t
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
ch->Close();
}

/*
* The ChannelHolder class serves two main purposes:
* 1. It acts as a unified wrapper for the different kinds of
* channels, i.e. Buffered and Unbuffered channels. This is
* similar to the ReaderHolder class.
* 2. It also helps us in TypeHiding. This is similar to the
* PlaceHolder implementations in variable.h and tensor.h.
*/
class ChannelHolder {
public:
template <typename T>
void Reset(size_t buffer_size) {
holder_.reset(new PlaceholderImpl<T>(buffer_size));
}

template <typename T>
bool Send(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
// Static cast should be safe because we have ensured that types are same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Send(data) : false;
}

template <typename T>
bool Receive(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Receive(data) : false;
}

void close() {
if (IsInitialized()) holder_->Close();
}

inline bool IsInitialized() const { return holder_ != nullptr; }

private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of ChannelHolder.
*/
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0;
virtual void Close() const = 0;
std::type_info type_;
};

template <typename T>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t buffer_size) : type_(std::type_index(typeid(T))) {
channel_.reset(MakeChannel<T>(buffer_size));
}

virtual const std::type_index Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual void Close() {
if (channel_) channel_->Close();
}

std::unique_ptr<Channel<T>*> channel_;
const std::type_index type_;
};

// Pointer to a PlaceholderImpl object
std::unique_ptr<Placeholder> holder_;
};

} // namespace framework
} // namespace paddle

Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <set>

#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
Expand Down Expand Up @@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::CHANNEL) {
var->GetMutable<ChannelHolder>();
} else if (var_type == proto::VarType::NCCL_COM) {
// GetMutable will be called in ncclInit
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
var_type);
}
}
Expand Down
54 changes: 52 additions & 2 deletions paddle/fluid/framework/var_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}

void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
mutable_channel_desc()->set_data_type(data_type);
break;
default:
mutable_tensor_desc()->set_data_type(data_type);
}
}

void VarDesc::SetDataTypes(
Expand All @@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
}

proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type();
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return channel_desc().data_type();
break;
default:
return tensor_desc().data_type();
}
}

std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
Expand All @@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return res;
}

void VarDesc::SetCapacity(int64_t capacity) {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
break;
default:
PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
this->Name());
}
}

void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
Expand Down Expand Up @@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}

const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.type().channel();
default:
PADDLE_THROW(
"Getting 'channel_desc' is not supported by the type of var %s.",
this->Name());
}
}

const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
Expand Down Expand Up @@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}

proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.mutable_type()->mutable_channel();
default:
PADDLE_THROW(
"Getting 'mutable_channel_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}

proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/var_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class VarDesc {
void SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type);

void SetCapacity(int64_t capacity);

proto::VarType::Type GetDataType() const;

std::vector<proto::VarType::Type> GetDataTypes() const;
Expand All @@ -106,8 +108,10 @@ class VarDesc {
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }

private:
const proto::VarType::ChannelDesc &channel_desc() const;
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::VarType::ChannelDesc *mutable_channel_desc();
proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/var_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
Expand All @@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
return proto::VarType_Type_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
return proto::VarType_Type_READER;
} else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
return proto::VarType_Type_CHANNEL;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
Expand All @@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case proto::VarType_Type_READER:
visitor(var.Get<ReaderHolder>());
return;
case proto::VarType_Type_CHANNEL:
visitor(var.Get<ChannelHolder>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ void BindVarDsec(py::module &m) {
.def("set_shapes", &VarDesc::SetShapes)
.def("set_dtype", &VarDesc::SetDataType)
.def("set_dtypes", &VarDesc::SetDataTypes)
.def("set_capacity", &VarDesc::SetCapacity)
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
Expand Down Expand Up @@ -246,6 +247,7 @@ void BindVarDsec(py::module &m) {
.value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
.value("CHANNEL", proto::VarType::CHANNEL)
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
.value("READER", proto::VarType::READER)
.value("NCCL_COM", proto::VarType::NCCL_COM);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <mutex> // for call_once
#include <unordered_map>
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
Expand Down

0 comments on commit 77ee8fb

Please sign in to comment.