
Commit

refactoring
rahul003 committed Mar 21, 2018
1 parent 7442cd1 commit 05ffa1b
Showing 2 changed files with 109 additions and 103 deletions.
14 changes: 7 additions & 7 deletions src/kvstore/kvstore_dist.h
@@ -237,8 +237,8 @@ class KVStoreDist : public KVStoreLocal {
// false means not to delete data when SArray is deleted
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
// issue pull
-DataHandleMode mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
-DataHandleMode::kCompressedPushPull : DataHandleMode::kDefaultPushPull;
+RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
+RequestType::kCompressedPushPull : RequestType::kDefaultPushPull;
int cmd = GetCommandType(mode, dtype);
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
@@ -383,7 +383,7 @@ class KVStoreDist : public KVStoreLocal {
char* data = (char *) small_buf.data().dptr_;
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
-int cmd = GetCommandType(DataHandleMode::kCompressedPushPull, dtype);
+int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); });
};
// acquire locks on both comm_buf and small_buf so that
@@ -408,7 +408,7 @@ class KVStoreDist : public KVStoreLocal {
char* data = (char *) send_buf.data().dptr_;
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
-int cmd = GetCommandType(DataHandleMode::kDefaultPushPull, dtype);
+int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });
@@ -442,7 +442,7 @@ class KVStoreDist : public KVStoreLocal {
<< pskv.keys << " size: " << size;
}
ps::SArray<char> vals(data, size * num_bytes, false);
-int cmd = GetCommandType(DataHandleMode::kRowSparsePushPull, send_buf.dtype());
+int cmd = GetCommandType(RequestType::kRowSparsePushPull, send_buf.dtype());
CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); });
};
Engine::Get()->PushAsync(
@@ -482,7 +482,7 @@ class KVStoreDist : public KVStoreLocal {
<< pskv.keys << " size: " << size;
}
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
-int cmd = GetCommandType(DataHandleMode::kRowSparsePushPull, recv_buf.dtype());
+int cmd = GetCommandType(RequestType::kRowSparsePushPull, recv_buf.dtype());
// copy indices to recv_buf. this needs to be done before ZPull
// because after pull is done, the callback function returns and locks are released.
// at this point, later functions may access the indices variable while copy happens
@@ -561,13 +561,13 @@ class KVStoreDist : public KVStoreLocal {
* Populates both push and pull pskv on first call
*/
inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push, int num_bytes) {

auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
int num_servers = krs.size();
CHECK_GT(num_servers, 0);

// represents size of data to be sent
size_t compr_size = gradient_compression_->GetCompressedSize(original_size);

mu_.lock();
PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
mu_.unlock();
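The net effect of the kvstore_dist.h changes above: every worker-side ZPush/ZPull call site now passes a RequestType (instead of the old DataHandleMode) together with the array's dtype into GetCommandType, so a single cmd integer describes the whole request. The sketch below summarizes which request type each call path selects; it is not code from the commit — the RequestType enum mirrors the one introduced in kvstore_dist_server.h further down, and ChooseRequestType is a hypothetical helper used only for illustration.

// Sketch only: summarizes the selection made at the call sites above.
// RequestType mirrors the enum from kvstore_dist_server.h below;
// ChooseRequestType is a hypothetical helper, not part of the commit.
enum class RequestType { kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull };

inline RequestType ChooseRequestType(bool row_sparse, bool gradient_compression_on) {
  if (row_sparse) return RequestType::kRowSparsePushPull;               // row-sparse push and pull
  if (gradient_compression_on) return RequestType::kCompressedPushPull; // compressed push and pull
  return RequestType::kDefaultPushPull;                                 // plain dense push and pull
}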
198 changes: 102 additions & 96 deletions src/kvstore/kvstore_dist_server.h
@@ -44,37 +44,43 @@ enum class CommandType {
kController, kStopServer, kSyncMode, kSetGradientCompression
};

-enum class DataHandleMode {
+enum class RequestType {
kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull
};

struct DataHandleType {
-DataHandleMode mode;
+RequestType valueType;
int dtype;
};

/*!
* Uses Cantor pairing function to generate a unique number given two numbers.
* This number can also be inverted to find the unique pair whose Cantor value is this number.
* Ref: https://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function
-* @param dtype
-* @param mode
-* @return
+* \param requestType RequestType
+* \param dtype integer
+* \return Cantor value of arguments
*/
-static int GetCommandType(DataHandleMode mode, int d) {
-int m = static_cast<int>(mode);
+static int GetCommandType(RequestType requestType, int d) {
+int m = static_cast<int>(requestType);
return (((m + d) * (m + d + 1)) / 2) + d;
}

-static DataHandleType DepairDataHandleType(int z) {
-int w = std::floor((std::sqrt(8 * z + 1) - 1)/2);
+/*!
+* Unpairs Cantor value and finds the two integers used to pair.
+* Then returns DataHandleType object with those numbers.
+* \param cmd DataHandleCommand generated by GetCommandType function
+* \return DataHandleType
+*/
+static DataHandleType UnpairDataHandleCommand(int cmd) {
+int w = std::floor((std::sqrt(8 * cmd + 1) - 1)/2);
int t = ((w * w) + w) / 2;

-int y = z - t;
+int y = cmd - t;
int x = w - y;
CHECK_GE(x, 0);
CHECK_GE(y, 0);
DataHandleType type;
-type.mode = static_cast<DataHandleMode>(x);
+type.valueType = static_cast<RequestType>(x);
type.dtype = y;
return type;
}
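The Cantor pairing above is easy to exercise in isolation. Below is a standalone sketch (not part of the commit) that copies the pairing and unpairing logic, minus the MXNet CHECK macros, and verifies they are inverses; for example, a kRowSparsePushPull request (enum value 1) with dtype code 2 pairs to cmd = ((1+2)*(1+2+1))/2 + 2 = 8, and unpairing 8 recovers (1, 2).

// Standalone sketch: replicates GetCommandType / UnpairDataHandleCommand
// from the diff above and round-trips them. Compile with any C++11 compiler.
#include <cassert>
#include <cmath>
#include <iostream>

enum class RequestType { kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull };

struct DataHandleType {
  RequestType valueType;
  int dtype;
};

static int GetCommandType(RequestType requestType, int d) {
  int m = static_cast<int>(requestType);
  return (((m + d) * (m + d + 1)) / 2) + d;  // Cantor pairing of (m, d)
}

static DataHandleType UnpairDataHandleCommand(int cmd) {
  int w = static_cast<int>(std::floor((std::sqrt(8.0 * cmd + 1) - 1) / 2));
  int t = ((w * w) + w) / 2;  // largest triangular number not exceeding cmd
  DataHandleType type;
  type.dtype = cmd - t;                                       // second element of the pair
  type.valueType = static_cast<RequestType>(w - type.dtype);  // first element of the pair
  return type;
}

int main() {
  int cmd = GetCommandType(RequestType::kRowSparsePushPull, 2);
  std::cout << "cmd = " << cmd << std::endl;  // prints 8
  DataHandleType back = UnpairDataHandleCommand(cmd);
  assert(back.valueType == RequestType::kRowSparsePushPull && back.dtype == 2);

  // the encoding is invertible for every (request type, dtype) combination
  for (int m = 0; m < 3; ++m) {
    for (int d = 0; d < 16; ++d) {
      DataHandleType r = UnpairDataHandleCommand(GetCommandType(static_cast<RequestType>(m), d));
      assert(static_cast<int>(r.valueType) == m && r.dtype == d);
    }
  }
  return 0;
}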
@@ -180,6 +186,7 @@ class KVStoreDistServer {
struct MergeBuf {
std::vector<ps::KVMeta> request;
NDArray array;
+// temp_array is used to cast received values as float32 for computation if required
NDArray temp_array;
};

@@ -205,15 +212,15 @@ class KVStoreDistServer {
void DataHandleEx(const ps::KVMeta& req_meta,
const ps::KVPairs<char>& req_data,
ps::KVServer<char>* server) {
-DataHandleType type = DepairDataHandleType(req_meta.cmd);
-switch (type.mode) {
-case DataHandleMode::kRowSparsePushPull:
+DataHandleType type = UnpairDataHandleCommand(req_meta.cmd);
+switch (type.valueType) {
+case RequestType::kRowSparsePushPull:
DataHandleRowSparse(type, req_meta, req_data, server);
break;
-case DataHandleMode::kCompressedPushPull:
+case RequestType::kCompressedPushPull:
DataHandleCompressed(type, req_meta, req_data, server);
break;
-case DataHandleMode::kDefaultPushPull:
+case RequestType::kDefaultPushPull:
DataHandleDefault(type, req_meta, req_data, server);
break;
}
@@ -232,8 +239,8 @@ class KVStoreDistServer {
// if no updater, just copy
CopyFromTo(merged->array, stored);
}
-// better to cast once and store than for each pull
-// we don't need to wait on this because stored wont go out of scope
+// better to cast once and store, than copy for each pull
+// we don't need to wait on this because unlike recvd, stored wont go out of scope
if (dtype != mshadow::kFloat32) {
auto& stored_dtype = store_[key].arr_dtype;
CopyFromTo(*stored, &stored_dtype, 0);
@@ -296,7 +303,10 @@ class KVStoreDistServer {
const NDArray& stored = (dtype == mshadow::kFloat32) ? store_[master_key].arr_fp32 :
store_[master_key].arr_dtype;
CHECK(!stored.is_none()) << "init " << master_key << " first";
-if (dtype != mshadow::kFloat32) stored.WaitToRead();
+// we already waited on arr_fp32 in ApplyUpdates
+if (dtype != mshadow::kFloat32) {
+stored.WaitToRead();
+}
auto shape = stored.shape();
auto unit_len = shape.ProdShape(1, shape.ndim());
int num_bytes = mshadow::mshadow_sizeof(dtype);
@@ -338,8 +348,7 @@ class KVStoreDistServer {

TBlob recv_blob;
MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
-recv_blob = TBlob((DType*)req_data.vals.data(), // NOLINT(*)
-dshape, cpu::kDevMask);
+recv_blob = TBlob((DType*)req_data.vals.data(), dshape, cpu::kDevMask);
})
NDArray recved = NDArray(recv_blob, 0);
stored = NDArray(kRowSparseStorage, dshape, Context(), false,
@@ -378,13 +387,54 @@ class KVStoreDistServer {
server->Response(req_meta);
}

+const NDArray ConstructReceivedRsp(DataHandleType type,
+const int master_key,
+const size_t num_rows,
+const TShape &stored_shape,
+const ps::KVPairs<char> &req_data) {
+int num_bytes= mshadow::mshadow_sizeof(type.dtype);
+auto unit_len = req_data.lens[1] / num_bytes;
+CHECK_GT(unit_len, 0);
+// indices
+std::vector<int64_t> indices(num_rows);
+DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
+// data
+TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask);
+size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
+TShape dshape(ds, ds + 2);
+TBlob recv_blob;
+MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+recv_blob = TBlob((DType*)req_data.vals.data(), dshape, cpu::kDevMask);
+})
+// row_sparse NDArray
+return NDArray(kRowSparseStorage, stored_shape, recv_blob, {idx_blob}, 0);
+}
+
+const NDArray& GetReceivedAsFloat(DataHandleType type,
+NDArrayStorageType storageType,
+const NDArray &recvd,
+MergeBuf *merged,
+const TShape &stored_shape) {
+if (type.dtype == mshadow::kFloat32) {
+return recvd;
+} else {
+if (merged->temp_array.is_none()) {
+if (storageType == kRowSparseStorage) {
+merged->temp_array = NDArray(kRowSparseStorage, stored_shape, Context());
+} else {
+merged->temp_array = NDArray(stored_shape, Context(), false, mshadow::kFloat32);
+}
+}
+CopyFromTo(recvd, merged->temp_array);
+return merged->temp_array;
+}
+}

void DataHandleRowSparse(DataHandleType type, const ps::KVMeta& req_meta,
-const ps::KVPairs<char>& req_data,
-ps::KVServer<char>* server) {
+const ps::KVPairs<char>& req_data, ps::KVServer<char>* server) {
int master_key = DecodeKey(req_data.keys[0]);
auto num_rows = req_data.keys.size() - 1;
auto& stored = store_[master_key].arr_fp32;
auto& stored_dtype = store_[master_key].arr_dtype;
if (req_meta.push) {
CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty";
CHECK_EQ(req_data.lens[0], 0);
@@ -395,55 +445,34 @@ class KVStoreDistServer {
InitRowSparseStored(type, master_key, num_rows, req_meta, req_data, server);
return;
}
+auto& merged = merge_buf_[master_key];
+auto& stored_shape = stored.shape();
// synced push
if (sync_mode_) {
if (log_verbose_) LOG(INFO) << "sync push: " << master_key << " " << req_data.keys;
-auto& merged = merge_buf_[master_key];
if (merged.array.is_none()) {
-merged.array = NDArray(kRowSparseStorage, stored.shape(), Context());
-merged.temp_array = NDArray(kRowSparseStorage, stored.shape(), Context());
+merged.array = NDArray(kRowSparseStorage, stored_shape, Context());
+merged.temp_array = NDArray(kRowSparseStorage, stored_shape, Context());
}
if (num_rows == 0) {
// reset to zeros
-if (merged.request.size() == 0) {
-merged.array = NDArray(kRowSparseStorage, stored.shape(), Context());
+if (merged.request.empty()) {
+merged.array = NDArray(kRowSparseStorage, stored_shape, Context());
} else {
// nothing to aggregate
}
merged.request.push_back(req_meta);
ApplyUpdates(master_key, type.dtype, &merged, &stored, server);
return;
} else {
-int num_bytes= mshadow::mshadow_sizeof(type.dtype);
-auto unit_len = req_data.lens[1] / num_bytes;
-CHECK_GT(unit_len, 0);
-// indices
-std::vector<int64_t> indices(num_rows);
-DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-
-// data
-TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask);
-size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-TShape dshape(ds, ds + 2);
-TBlob recv_blob;
-MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
-recv_blob = TBlob((DType*)req_data.vals.data(), // NOLINT(*)
-dshape, cpu::kDevMask);
-})
-
-// row_sparse NDArray
-NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0);
-
-if (merged.request.size() == 0) {
+const NDArray recved = ConstructReceivedRsp(type, master_key, num_rows,
+stored_shape, req_data);
+if (merged.request.empty()) {
CopyFromTo(recved, &merged.array, 0);
merged.array.WaitToRead();
} else {
-if (type.dtype != mshadow::kFloat32) {
-CopyFromTo(recved, merged.temp_array);
-AccumulateRowSparseGrads(merged.temp_array, &merged);
-} else {
-AccumulateRowSparseGrads(recved, &merged);
-}
+const NDArray& recved_realt = GetReceivedAsFloat(type, kRowSparseStorage,
+recved, &merged, stored_shape);
+AccumulateRowSparseGrads(recved_realt, &merged);
}
merged.request.push_back(req_meta);
ApplyUpdates(master_key, type.dtype, &merged, &stored, server);
@@ -455,31 +484,10 @@ class KVStoreDistServer {
server->Response(req_meta);
return;
}
-auto& merged = merge_buf_[master_key];
-auto unit_len = req_data.lens[1];
-CHECK_GT(unit_len, 0);
-// indices
-std::vector<int64_t> indices(num_rows);
-DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask);
-size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-TShape dshape(ds, ds + 2);
-TBlob recv_blob;
-MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
-recv_blob = TBlob((DType*)req_data.vals.data(), // NOLINT(*)
-dshape, cpu::kDevMask);
-})
-NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0);
-NDArray recved_realt;
-if (type.dtype == mshadow::kFloat32) {
-recved_realt = recved;
-} else {
-if (merged.temp_array.is_none()) {
-merged.temp_array = NDArray(kRowSparseStorage, stored.shape(), Context());
-}
-CopyFromTo(recved, merged.temp_array);
-recved_realt = merged.temp_array;
-}
+const NDArray recved = ConstructReceivedRsp(type, master_key, num_rows,
+stored_shape, req_data);
+const NDArray& recved_realt = GetReceivedAsFloat(type, kRowSparseStorage,
+recved, &merged, stored_shape);
exec_.Exec([this, master_key, &recved_realt, &stored](){
CHECK(updater_);
updater_(master_key, recved_realt, &stored);
@@ -606,11 +614,9 @@ class KVStoreDistServer {
NDArray recved = NDArray(recv_blob, 0);
if (stored.is_none()) {
// initialization
-// stored is real_t
-stored = NDArray(dshape, Context(), false, mshadow::DataType<real_t>::kFlag);
+stored = NDArray(dshape, Context(), false, mshadow::kFloat32);
if (type.dtype != mshadow::kFloat32) {
stored_dtype = NDArray(dshape, Context(), false, type.dtype);
-// no need to wait on stored_dtype because stored will be in scope
}
CopyFromTo(recved, &stored, 0);
if (type.dtype != mshadow::kFloat32) {
@@ -622,26 +628,26 @@ class KVStoreDistServer {
// synced push
auto& merged = merge_buf_[key];
if (merged.array.is_none()) {
-merged.array = NDArray(dshape, Context(), false, mshadow::DataType<real_t>::kFlag);
-merged.temp_array = NDArray(dshape, Context(), false, mshadow::DataType<real_t>::kFlag);
+merged.array = NDArray(dshape, Context(), false, mshadow::kFloat32);
+merged.temp_array = NDArray(dshape, Context(), false, mshadow::kFloat32);
}
if (merged.request.size() == 0) {
CopyFromTo(recved, merged.array);
} else {
-if (type.dtype == mshadow::DataType<real_t>::kFlag) {
-merged.array += recved;
-} else {
-CopyFromTo(recved, merged.temp_array);
-merged.array += merged.temp_array;
-}
+const NDArray& recved_float = GetReceivedAsFloat(type, kDefaultStorage,
+recved, &merged, dshape);
+merged.array += recved_float;
}
merged.request.push_back(req_meta);
ApplyUpdates(key, type.dtype, &merged, &stored, server);
} else {
// async push
-exec_.Exec([this, key, &recved, &stored](){
+auto& merged = merge_buf_[key];
+const NDArray& recved_float = GetReceivedAsFloat(type, kDefaultStorage,
+recved, &merged, dshape);
+exec_.Exec([this, key, &recved_float, &stored](){
CHECK(updater_);
-updater_(key, recved, &stored);
+updater_(key, recved_float, &stored);
});
server->Response(req_meta);
if (type.dtype != mshadow::kFloat32) {
