Exposing Channel to be used as a Variable and integrating with Fluid #8486

Merged
merged 35 commits
Feb 23, 2018
Changes from 34 commits
Commits
0db0632
Adding set_capacity method support
Feb 20, 2018
3451c4e
Adding Python for make_channel
Feb 20, 2018
fcd086a
Updating notest_concurrency
Feb 20, 2018
25674f0
Write python for make_channel method
Feb 20, 2018
a393b4e
Write python for make_channel method
Feb 20, 2018
2f612da
Fix make_channel and test
Feb 20, 2018
68d5977
Placeholder ops for channel send, recv and close
Feb 20, 2018
e73d30f
Adding ToTypeIndex method to var_type.h
Feb 21, 2018
267c3ad
Add var_type.h to channel:
Feb 21, 2018
b499d8d
Added POD_Type to the method
Feb 21, 2018
f6803e7
Add CHANNEL to executor
Feb 21, 2018
53d82aa
Updated get and set DataType to accommodate Channels
Feb 21, 2018
029337f
Updating get and set to incorporate channels
Feb 21, 2018
6a100a7
Adding CHANNEL as supported VarType in protobuf
Feb 21, 2018
64cc34d
Removing unnecessary import
Feb 21, 2018
d87a65d
Fixing VarDesc to adapt to Channel as VarType
Feb 21, 2018
f3d6a2d
Add channel.h to executor
Feb 21, 2018
20817bd
Remove include from channel
Feb 21, 2018
52a96aa
Updated var_type to support Channel as var type
Feb 21, 2018
6a01d3e
Adding get_channel to pybind
Feb 22, 2018
4022dcc
Added ChannelHolder
Feb 22, 2018
dfea686
Adding make_channel as an op
Feb 22, 2018
1c0f569
Adding ChannelHolder in channel
Feb 22, 2018
9fa0ca9
Merge branch 'channel_cpp' into refine_channel
abhinavarora Feb 22, 2018
210f0be
Merge pull request #1 from abhinavarora/refine_channel
abhinavarora Feb 22, 2018
e99f230
Fixing merge conflict
Feb 22, 2018
f33031b
Fixing typo
Feb 22, 2018
b3a53b1
Merge branch 'channel_cpp' of github.com:abhinavarora/Paddle into cha…
Feb 22, 2018
a79de61
Commenting out operators in concurrency
Feb 22, 2018
255cd0b
Merge branch 'channel_cpp' of https://github.com/abhinavarora/Paddle …
Feb 22, 2018
7d648b8
Removing totypeid right now since we don't need it.
Feb 22, 2018
5206574
Reverting python changes
Feb 22, 2018
150db1b
Fixing typo in framework.py
Feb 22, 2018
d4d493d
Modify comments for ReaderHolder
Feb 22, 2018
6e37b8b
Merge remote-tracking branch 'origin/develop' into channel_cpp
Feb 22, 2018
73 changes: 73 additions & 0 deletions paddle/fluid/framework/channel.h
@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once

#include <stddef.h> // for size_t
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
@@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
ch->Close();
}

/*
* The ChannelHolder class serves two main purposes:
* 1. It acts as a unified wrapper for the different kinds of
* channels, i.e. Buffered and Unbuffered channels. This is
* similar to the ReaderHolder class.
* 2. It also helps us with type hiding. This is similar to the
* PlaceHolder implementations in variable.h and tensor.h.
*/
class ChannelHolder {
public:
template <typename T>
void Reset(size_t buffer_size) {
holder_.reset(new PlaceholderImpl<T>(buffer_size));
}

template <typename T>
bool Send(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
// The static_cast is safe because we have ensured the types are the same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Send(data) : false;
}

template <typename T>
bool Receive(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Receive(data) : false;
}

void close() {
if (IsInitialized()) holder_->Close();
}

inline bool IsInitialized() const { return holder_ != nullptr; }

private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of ChannelHolder.
*/
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0;
virtual void Close() = 0;
};

template <typename T>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t buffer_size) : type_(std::type_index(typeid(T))) {
channel_.reset(MakeChannel<T>(buffer_size));
}

virtual const std::type_index Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual void Close() {
if (channel_) channel_->Close();
}

std::unique_ptr<Channel<T>> channel_;
Collaborator:
It looks to me like it might be possible to remove the current base class Channel from channel.h and rename ChannelHolder to Channel, if we write here

std::unique_ptr<
    boost::variant<
        BufferedChannel<T>, 
        UnbufferedChannel<T>>> channel_;

I am not an expert on boost::variant, but I see it is used in many places in our codebase. For example:

typedef boost::variant<CUDAPlace, CPUPlace> Place;

@wangkuiyi (Collaborator), Feb 22, 2018:
I verified that we could use the polymorphism provided by boost::variant to call Send:

#include <iostream>
#include <boost/variant.hpp>

class BufferedChannel {
public:
  bool Send() {
    std::cout << "BufferedChannel::Send\n";
    return false;
  }
};

class UnbufferedChannel {
public:
  bool Send() {
    std::cout << "UnbufferedChannel::Send\n";
    return false;
  }
};

class SendVisitor : public boost::static_visitor<bool> {
public:
  template <typename T>
  bool operator()(T& t) const {
    return t.Send();
  }
};

int main() {
  boost::variant<BufferedChannel, UnbufferedChannel> ch;
  BufferedChannel bch;
  ch = bch;
  boost::apply_visitor(SendVisitor(), ch);
  return 0;
}

Contributor:
Thank you @wangkuiyi. This looks interesting. I'll look into this visitor-pattern polymorphism.

@wangkuiyi (Collaborator), Feb 22, 2018:

Extend the above example to work with unique_ptr:

#include <iostream>
#include <memory>
#include <boost/variant.hpp>

class BufferedChannel {
public:
  bool Send() {
    std::cout << "BufferedChannel::Send\n";
    return false;
  }
};

class UnbufferedChannel {
public:
  bool Send() {
    std::cout << "UnbufferedChannel::Send\n";
    return false;
  }
};

class SendVisitor : public boost::static_visitor<bool> {
public:
  template <typename T>
  bool operator()(T* t) const {
    return t->Send();
  }
};

int main() {
  typedef boost::variant<BufferedChannel*, UnbufferedChannel*> ChannelVariant;
  std::unique_ptr<ChannelVariant> ch;
  ch.reset(new ChannelVariant(new BufferedChannel));
  boost::apply_visitor(SendVisitor(), *ch.get());
  return 0;
}

const std::type_index type_;
};

// Pointer to a PlaceholderImpl object
std::unique_ptr<Placeholder> holder_;
};

} // namespace framework
} // namespace paddle

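The type-erasure idiom that ChannelHolder introduces can be exercised in isolation. The sketch below is a hedged, self-contained mock: the queue-backed `Channel` stand-in, its non-blocking capacity semantics, and all names here are illustrative assumptions, not Paddle's actual blocking channel implementation.

```cpp
#include <cassert>
#include <memory>
#include <queue>
#include <typeindex>

// Toy stand-in for a buffered channel (Paddle's real Channel<T> blocks;
// this one just fails fast when full or empty).
template <typename T>
class Channel {
 public:
  explicit Channel(size_t cap) : cap_(cap) {}
  bool Send(T* v) {
    if (closed_ || buf_.size() >= cap_) return false;
    buf_.push(*v);
    return true;
  }
  bool Receive(T* v) {
    if (buf_.empty()) return false;
    *v = buf_.front();
    buf_.pop();
    return true;
  }
  void Close() { closed_ = true; }

 private:
  size_t cap_;
  bool closed_ = false;
  std::queue<T> buf_;
};

// Type-erasing holder mirroring the Placeholder/PlaceholderImpl idiom
// from the diff above: the element type T is hidden behind a base class
// and recovered via a std::type_index check before the static_cast.
class ChannelHolder {
 public:
  template <typename T>
  void Reset(size_t buffer_size) {
    holder_.reset(new PlaceholderImpl<T>(buffer_size));
  }

  template <typename T>
  bool Send(T* data) {
    if (!holder_) return false;
    assert(holder_->Type() == std::type_index(typeid(T)));
    return static_cast<Channel<T>*>(holder_->Ptr())->Send(data);
  }

  template <typename T>
  bool Receive(T* data) {
    if (!holder_) return false;
    assert(holder_->Type() == std::type_index(typeid(T)));
    return static_cast<Channel<T>*>(holder_->Ptr())->Receive(data);
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() {}
    virtual std::type_index Type() const = 0;
    virtual void* Ptr() const = 0;
  };

  template <typename T>
  struct PlaceholderImpl : Placeholder {
    explicit PlaceholderImpl(size_t n)
        : type_(typeid(T)), channel_(new Channel<T>(n)) {}
    std::type_index Type() const override { return type_; }
    void* Ptr() const override { return channel_.get(); }

    const std::type_index type_;
    std::unique_ptr<Channel<T>> channel_;
  };

  std::unique_ptr<Placeholder> holder_;
};
```

The key design point is that ChannelHolder itself is not a template, so a Variable can store it without knowing the element type; only the Send/Receive call sites name T, and the type check catches mismatched uses.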
5 changes: 4 additions & 1 deletion paddle/fluid/framework/executor.cc
@@ -17,6 +17,7 @@ limitations under the License. */
#include <set>

#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::CHANNEL) {
var->GetMutable<ChannelHolder>();
} else if (var_type == proto::VarType::NCCL_COM) {
// GetMutable will be called in ncclInit
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
var_type);
}
}
54 changes: 52 additions & 2 deletions paddle/fluid/framework/var_desc.cc
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}

void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
mutable_channel_desc()->set_data_type(data_type);
break;
default:
mutable_tensor_desc()->set_data_type(data_type);
}
}

void VarDesc::SetDataTypes(
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
}

proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type();
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return channel_desc().data_type();
default:
return tensor_desc().data_type();
}
}

std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return res;
}

void VarDesc::SetCapacity(int64_t capacity) {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
break;
default:
PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
this->Name());
}
}

void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}

const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.type().channel();
default:
PADDLE_THROW(
"Getting 'channel_desc' is not supported by the type of var %s.",
this->Name());
}
}

const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}

proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.mutable_type()->mutable_channel();
default:
PADDLE_THROW(
"Getting 'mutable_channel_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}

proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
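The pattern these VarDesc changes follow — route the dtype to the channel descriptor for CHANNEL vars, fall through to the tensor descriptor otherwise, and throw when a field (like capacity) makes no sense for the var's type — can be illustrated with a stripped-down mock. All names below are hypothetical; the real code dispatches on protobuf messages.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

enum class VarType { LOD_TENSOR, CHANNEL };

// Hypothetical stand-in for VarDesc: the dtype lives on the channel
// descriptor for CHANNEL vars and on the tensor descriptor otherwise.
struct VarDescSketch {
  explicit VarDescSketch(VarType t) : type(t) {}

  void SetDataType(int dtype) {
    switch (type) {
      case VarType::CHANNEL:
        channel_dtype = dtype;
        break;
      default:
        tensor_dtype = dtype;
        break;
    }
  }

  int GetDataType() const {
    switch (type) {
      case VarType::CHANNEL:
        return channel_dtype;
      default:
        return tensor_dtype;
    }
  }

  // Mirrors VarDesc::SetCapacity: only CHANNEL vars carry a capacity,
  // so any other type is rejected (PADDLE_THROW in the real code).
  void SetCapacity(int64_t cap) {
    if (type != VarType::CHANNEL) {
      throw std::runtime_error("'capacity' is not supported by this var type");
    }
    channel_capacity = cap;
  }

  VarType type;
  int tensor_dtype = -1;
  int channel_dtype = -1;
  int64_t channel_capacity = -1;
};
```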
4 changes: 4 additions & 0 deletions paddle/fluid/framework/var_desc.h
@@ -85,6 +85,8 @@ class VarDesc {
void SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type);

void SetCapacity(int64_t capacity);

proto::VarType::Type GetDataType() const;

std::vector<proto::VarType::Type> GetDataTypes() const;
@@ -106,8 +108,10 @@ class VarDesc {
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }

private:
const proto::VarType::ChannelDesc &channel_desc() const;
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::VarType::ChannelDesc *mutable_channel_desc();
proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

6 changes: 6 additions & 0 deletions paddle/fluid/framework/var_type.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
return proto::VarType_Type_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
return proto::VarType_Type_READER;
} else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
return proto::VarType_Type_CHANNEL;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case proto::VarType_Type_READER:
visitor(var.Get<ReaderHolder>());
return;
case proto::VarType_Type_CHANNEL:
visitor(var.Get<ChannelHolder>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
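The VisitVarType change adds one more case to a switch that forwards the concrete payload to a caller-supplied visitor. A standalone sketch of that dispatch (with mock types in place of Paddle's Variable, LoDTensor, ReaderHolder, and ChannelHolder — everything here is illustrative):

```cpp
#include <cassert>
#include <string>

enum class VarTypeTag { LOD_TENSOR, READER, CHANNEL };

struct LoDTensorMock {};
struct ReaderHolderMock {};
struct ChannelHolderMock {};

// Mock Variable: a tag plus one member per payload kind.
struct VariableMock {
  VarTypeTag tag;
  LoDTensorMock tensor;
  ReaderHolderMock reader;
  ChannelHolderMock channel;
};

// A visitor provides one overload per supported var type; adding
// CHANNEL support means adding an overload and a switch case.
struct TypeNameVisitor {
  std::string operator()(const LoDTensorMock&) const { return "LOD_TENSOR"; }
  std::string operator()(const ReaderHolderMock&) const { return "READER"; }
  std::string operator()(const ChannelHolderMock&) const { return "CHANNEL"; }
};

template <typename Visitor>
std::string VisitVarType(const VariableMock& var, Visitor visitor) {
  switch (var.tag) {
    case VarTypeTag::LOD_TENSOR:
      return visitor(var.tensor);
    case VarTypeTag::READER:
      return visitor(var.reader);
    case VarTypeTag::CHANNEL:
      return visitor(var.channel);
  }
  return "UNKNOWN";  // mirrors the PADDLE_THROW default in the real code
}
```

Overload resolution picks the right `operator()` at compile time for each case, so forgetting the CHANNEL overload while adding the CHANNEL case would fail to compile rather than misbehave at runtime.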
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/protobuf.cc
@@ -208,6 +208,7 @@ void BindVarDsec(py::module &m) {
.def("set_shapes", &VarDesc::SetShapes)
.def("set_dtype", &VarDesc::SetDataType)
.def("set_dtypes", &VarDesc::SetDataTypes)
.def("set_capacity", &VarDesc::SetCapacity)
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
@@ -238,6 +239,7 @@ void BindVarDsec(py::module &m) {
.value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
.value("CHANNEL", proto::VarType::CHANNEL)
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
.value("READER", proto::VarType::READER)
.value("NCCL_COM", proto::VarType::NCCL_COM);
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -17,6 +17,7 @@ limitations under the License. */
#include <mutex> // for call_once
#include <unordered_map>
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"