
server: initial commit #151

Merged
merged 32 commits on Nov 29, 2019
Changes from all commits
Commits (32):
f3d7dc6
server: first commit
bobzhuyb Nov 15, 2019
5de3a9c
server: quick fix
bobzhuyb Nov 15, 2019
4663b55
server: fix building
ymjiang Nov 15, 2019
d76635a
server: fix launcher
bobzhuyb Nov 15, 2019
c9384b2
server: add new server logic
ymjiang Nov 17, 2019
7b8f563
common: fix compile warning
ymjiang Nov 17, 2019
c8ab396
server: attempt to fix rdma compile
ymjiang Nov 17, 2019
ffa5c5b
server: add init_global_env()
ymjiang Nov 17, 2019
b1f7ae8
server: improve the log
ymjiang Nov 17, 2019
a4031ff
server: improve engine load balance
ymjiang Nov 17, 2019
3893011
server: use omp to speed up memcpy (#152)
bobzhuyb Nov 17, 2019
53eda07
server: remove recvmap_mu_
ymjiang Nov 17, 2019
3404d61
server: avoid global mutex
ymjiang Nov 17, 2019
a51a578
server: remove recv->merged copy
ymjiang Nov 18, 2019
2154910
server: attempt to fix correctness
ymjiang Nov 18, 2019
8a77d73
server: PriorityQueue for server (#153)
bobzhuyb Nov 18, 2019
ce0dffd
server: fix disable schedule
ymjiang Nov 18, 2019
c7b1309
server: fix heap sorting correctness
ymjiang Nov 19, 2019
2640d96
server: another fix of heap correctness
ymjiang Nov 19, 2019
bf3fc3e
3rdparty: update pslite
ymjiang Nov 19, 2019
c1487f3
server: improve priority compare
ymjiang Nov 19, 2019
da1328d
common: disable numa control when not summing locally (#157)
bobzhuyb Nov 20, 2019
ac8deea
docker: update dockerfile for new server
bobzhuyb Nov 28, 2019
17780ff
docker: clean up dockerfiles
bobzhuyb Nov 28, 2019
6c1ff86
docker: refactor and test all (#169)
ymjiang Nov 29, 2019
0bfa2d4
docs: clean examples and update tutorial
ymjiang Nov 29, 2019
39ba9e2
docker: quick fix
ymjiang Nov 29, 2019
94324cd
merge master into this branch
ymjiang Nov 29, 2019
9482a33
docs: fix tutorial
ymjiang Nov 29, 2019
1b88762
docs: update performance
ymjiang Nov 29, 2019
065ce08
docs: update readme
ymjiang Nov 29, 2019
c0bbdee
docs: improve legend in readme
ymjiang Nov 29, 2019
2 changes: 1 addition & 1 deletion 3rdparty/ps-lite
23 changes: 9 additions & 14 deletions README.md
@@ -5,32 +5,27 @@

BytePS is a high performance and general distributed training framework. It supports TensorFlow, Keras, PyTorch, and MXNet, and can run on either TCP or RDMA network.

BytePS outperforms existing open-sourced distributed training frameworks by a large margin. For example, on a popular public cloud and with the same number of GPUs, BytePS can *double the training speed* (see below), compared with [Horovod](https://github.com/horovod/horovod)+[NCCL](https://github.com/NVIDIA/nccl).
BytePS outperforms existing open-sourced distributed training frameworks by a large margin. For example, on BERT-large training, BytePS can achieve ~90% scaling efficiency with 256 GPUs (see below), which is much higher than [Horovod](https://github.com/horovod/horovod)+[NCCL](https://github.com/NVIDIA/nccl).

## News

- Use [the ssh launcher](launcher/) to launch your distributed jobs
- Asynchronous training support for
[PyTorch](https://github.com/bytedance/byteps/pull/121),
[TensorFlow](https://github.com/bytedance/byteps/pull/122),
[MXNet](https://github.com/bytedance/byteps/pull/114)
- Find your training stragglers using [server timeline](docs/timeline.md)
- [Improved key distribution strategy for better load-balancing](https://github.com/bytedance/byteps/pull/116)
- [New Server](https://github.com/bytedance/byteps/pull/151): We improve the server performance by a large margin, and it is now independent of MXNet KVStore. Try our [new docker images](docker/).
- Use [the ssh launcher](launcher/) to launch your distributed jobs
- [Improved key distribution strategy for better load-balancing](https://github.com/bytedance/byteps/pull/116)
- [Improved RDMA robustness](https://github.com/bytedance/byteps/pull/91)

## Performance

For demonstration, we test two models: VGG16 (communication-intensive) and Resnet50 (computation-intensive). Both models are trained using fp32.
We show our experiment on BERT-large training, which is based on the GluonNLP toolkit. The model is trained with mixed precision.

We use Tesla V100 16GB GPUs and set batch size equal to 64 *per GPU*. The machines are in fact VMs on a popular public cloud. Each machine has 8 V100 GPUs with NVLink-enabled. Machines are inter-connected with 20 Gbps TCP/IP network.
We use Tesla V100 32GB GPUs and set the batch size to 64 per GPU. Each machine has 8 V100 GPUs (32GB memory) with NVLink enabled. Machines are interconnected by a 100 Gbps RoCEv2 network.

BytePS outperforms Horovod (NCCL) by 44% for Resnet50, and 100% for VGG16.
BytePS achieves ~90% scaling efficiency for BERT-large. The code is available [here](https://github.com/ymjiang/gluon-nlp/tree/bert-byteps/scripts/bert).

<img src="/docs/images/perf_tcp_vgg16.png" width="360" height="220"><img src="/docs/images/perf_tcp_resnet50.png" width="360" height="220">
![BERT-Large](https://user-images.githubusercontent.com/13852819/69874496-1ca43600-12f6-11ea-997b-b023e4c93360.png)

You can reproduce the results using the Dockerfiles and example scripts we provide.

Evaluation on RDMA networks can be found at [performance.md](docs/performance.md).
More evaluation in different scenarios can be found at [performance.md](docs/performance.md).

## Goodbye MPI, Hello Cloud

5 changes: 5 additions & 0 deletions byteps/common/common.h
@@ -17,8 +17,11 @@
#ifndef BYTEPS_COMMON_H
#define BYTEPS_COMMON_H

#ifndef BYTEPS_BUILDING_SERVER
#include <cuda_runtime.h>
#include <nccl.h>
#endif

#include <atomic>
#include <functional>
#include <memory>
@@ -217,7 +220,9 @@ enum class RequestType {

int GetCommandType(RequestType requestType, int d);

#ifndef BYTEPS_BUILDING_SERVER
ncclDataType_t getNcclDataType(DataType dtype);
#endif

int getDataTypeLength(int dtype);

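The new BYTEPS_BUILDING_SERVER guard is what lets the common headers compile in a server-only build that has no CUDA or NCCL toolchain. A minimal sketch of the pattern follows; the macro name comes from this diff, while the file name, function names, and compile commands are illustrative assumptions, not the project's actual build setup.

```cpp
// Sketch only: how a BYTEPS_BUILDING_SERVER-style guard keeps GPU code out of
// the server build. Hypothetical compile lines:
//   server build:  g++ -DBYTEPS_BUILDING_SERVER -c example.cc
//   worker build:  nvcc -c example.cc
#include <cstddef>
#include <cstring>

#ifndef BYTEPS_BUILDING_SERVER
#include <cuda_runtime.h>  // only the worker build needs the CUDA toolchain
#endif

// Shared by worker and server builds.
size_t ElementSize(int dtype) { return dtype == 0 ? 4 : 2; }

#ifndef BYTEPS_BUILDING_SERVER
// Worker-only path: data may live on the GPU.
void CopyToHost(void* dst, const void* src, size_t len) {
  cudaMemcpy(dst, src, len, cudaMemcpyDeviceToHost);
}
#else
// Server build: plain CPU copy, no CUDA symbols referenced anywhere.
void CopyToHost(void* dst, const void* src, size_t len) {
  std::memcpy(dst, src, len);
}
#endif
```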
36 changes: 34 additions & 2 deletions byteps/common/cpu_reducer.cc
@@ -13,19 +13,32 @@
// limitations under the License.
// =============================================================================

#ifndef BYTEPS_BUILDING_SERVER
#include "global.h"
#endif

#include "cpu_reducer.h"

namespace byteps {
namespace common {

CpuReducer::CpuReducer(std::shared_ptr<BytePSComm> comm) {

#ifndef BYTEPS_BUILDING_SERVER
std::vector<int> peers;
auto pcie_size = BytePSGlobal::GetPcieSwitchSize();
for (int i = BytePSGlobal::GetLocalRank() % pcie_size;
i < BytePSGlobal::GetLocalSize(); i += pcie_size) {
peers.push_back(i);
}
_comm = std::make_shared<BytePSCommSocket>(comm, std::string("cpu"), peers);
if (comm) {
_comm = std::make_shared<BytePSCommSocket>(comm, std::string("cpu"), peers);
}
else {
_comm = nullptr;
}
#endif

if (getenv("BYTEPS_OMP_THREAD_PER_GPU")) {
_num_threads = atoi(getenv("BYTEPS_OMP_THREAD_PER_GPU"));
} else {
@@ -34,9 +47,14 @@ CpuReducer::CpuReducer(std::shared_ptr<BytePSComm> comm) {
return;
}

#ifndef BYTEPS_BUILDING_SERVER
bool CpuReducer::isRoot() {
if (!_comm) {
return false;
}
return (_comm->getRoot() == BytePSGlobal::GetLocalRank());
}
#endif

int CpuReducer::sum(void* dst, void* src, size_t len, DataType dtype) {
switch (dtype) {
@@ -64,7 +82,7 @@ int CpuReducer::sum(void* dst, void* src, size_t len, DataType dtype) {
BPS_CHECK(0) << "Unsupported data type: " << dtype;
}
return 0;
}
}

template <typename T>
int CpuReducer::_sum(T* dst, T* src, size_t len) {
@@ -190,5 +208,19 @@ int CpuReducer::_sum_float16(void* dst, void* src1, void* src2, size_t len) {
return 0;
}

int CpuReducer::copy(void* dst, void* src, size_t len) {
auto in = (float*)src;
auto out = (float*)dst;
#pragma omp parallel for simd num_threads(_num_threads)
for (size_t i = 0; i < len / 4; ++i) {
out[i] = in[i];
}
if (len % 4) {
std::memcpy(out + len / 4, in + len / 4, len % 4);
}
return 0;
}


} // namespace common
} // namespace byteps
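The copy() added above speeds up what would otherwise be a single-threaded memcpy: the buffer is treated as 4-byte words so OpenMP can split the loop across _num_threads, and the trailing len % 4 bytes are copied with a scalar memcpy. Below is a standalone sketch of the same idea, assuming compilation with -fopenmp; the function and buffer names are illustrative and not part of the diff.

```cpp
// Sketch: word-wise parallel copy with a scalar tail, mirroring CpuReducer::copy.
#include <cstddef>
#include <cstring>
#include <vector>

void parallel_copy(void* dst, const void* src, size_t len, int num_threads) {
  auto in = reinterpret_cast<const float*>(src);
  auto out = reinterpret_cast<float*>(dst);
  // Each thread copies a contiguous slice of 4-byte words.
#pragma omp parallel for simd num_threads(num_threads)
  for (size_t i = 0; i < len / 4; ++i) {
    out[i] = in[i];
  }
  // The last (len % 4) bytes do not form a full word; copy them directly.
  if (len % 4) {
    std::memcpy(out + len / 4, in + len / 4, len % 4);
  }
}

int main() {
  std::vector<char> a(1027, 'x'), b(1027, 0);
  parallel_copy(b.data(), a.data(), a.size(), 4);
  return std::memcmp(a.data(), b.data(), a.size()) == 0 ? 0 : 1;
}
```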
17 changes: 16 additions & 1 deletion byteps/common/cpu_reducer.h
@@ -22,10 +22,16 @@
#endif

#include <memory>
#include <cstring>
#include "common.h"
#include "communicator.h"
#include "logging.h"

#ifndef BYTEPS_BUILDING_SERVER
#include "communicator.h"
#else
typedef void BytePSComm;
#endif

#include <stdint.h>

namespace byteps {
@@ -41,8 +47,17 @@ class CpuReducer {

int sum(void* dst, void* src, size_t len, DataType dtype);
int sum(void* dst, void* src1, void* src2, size_t len, DataType dtype);
int copy(void* dst, void* src, size_t len);

#ifndef BYTEPS_BUILDING_SERVER
bool isRoot();
std::shared_ptr<BytePSComm> getComm() { return _comm; }
#endif


DataType GetDataType(int dtype) {
return static_cast<DataType>(dtype);
}

private:
#if __AVX__ && __F16C__
4 changes: 2 additions & 2 deletions byteps/common/global.cc
@@ -338,7 +338,7 @@ uint64_t BytePSGlobal::Hash_DJB2(uint64_t key) {
auto str = std::to_string(key).c_str();
uint64_t hash = 5381;
int c;
while (c = *str) { // hash(i) = hash(i-1) * 33 ^ str[i]
while ((c = *str)) { // hash(i) = hash(i-1) * 33 ^ str[i]
hash = ((hash << 5) + hash) + c;
str++;
}
@@ -349,7 +349,7 @@ uint64_t BytePSGlobal::Hash_SDBM(uint64_t key) {
auto str = std::to_string(key).c_str();
uint64_t hash = 0;
int c;
while (c = *str) { // hash(i) = hash(i-1) * 65599 + str[i]
while ((c = *str)) { // hash(i) = hash(i-1) * 65599 + str[i]
hash = c + (hash << 6) + (hash << 16) - hash;
str++;
}
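The only change to the two hash functions is the extra pair of parentheses around the assignment in the loop condition, which makes the assignment-then-test explicit and silences -Wparentheses. For reference, a standalone sketch of the DJB2 variant; the function name is illustrative, and the sketch binds the temporary std::string to a local so the pointer returned by c_str() stays valid for the whole loop.

```cpp
// Sketch: standalone DJB2 hash over the decimal string of a key,
// with the parenthesized assignment-in-condition from the diff.
#include <cstdint>
#include <string>

uint64_t HashDJB2(uint64_t key) {
  std::string s = std::to_string(key);  // keep the string alive while hashing
  const char* str = s.c_str();
  uint64_t hash = 5381;
  int c;
  while ((c = *str)) {  // hash(i) = hash(i-1) * 33 ^ str[i]
    hash = ((hash << 5) + hash) + c;
    str++;
  }
  return hash;
}
```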
18 changes: 11 additions & 7 deletions byteps/common/shared_memory.cc
@@ -54,22 +54,26 @@ std::vector<void*> BytePSSharedMemory::openPcieSharedMemory(uint64_t key,
for (int i = 0; i < BytePSGlobal::GetPcieSwitchNum(); i++) {
auto prefix = std::string("BytePS_Pcie") + std::to_string(i) + "_Shm_";
if (BytePSGlobal::IsDistributed()) {
if (i <= numa_max_node()) {
numa_set_preferred(i);
if (BytePSGlobal::IsCrossPcieSwitch()) {
if (i <= numa_max_node()) {
numa_set_preferred(i);
r.push_back(openSharedMemory(prefix, key, size));
numa_set_preferred(-1);
} else {
numa_set_preferred(numa_max_node());
r.push_back(openSharedMemory(prefix, key, size));
numa_set_preferred(-1);
}
} else {
numa_set_preferred(numa_max_node());
r.push_back(openSharedMemory(prefix, key, size));
}
r.push_back(openSharedMemory(prefix, key, size));
numa_set_preferred(-1);
} else {
if (BytePSGlobal::IsCrossPcieSwitch()) {
numa_set_interleave_mask(numa_all_nodes_ptr);
r.push_back(openSharedMemory(prefix, key, size));
numa_set_interleave_mask(numa_no_nodes_ptr);
} else {
numa_set_preferred(0);
r.push_back(openSharedMemory(prefix, key, size));
numa_set_preferred(-1);
}
}
}
23 changes: 23 additions & 0 deletions byteps/server/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import ctypes
import os
from byteps.common import get_ext_suffix

dll_path = os.path.join(os.path.dirname(__file__),
'c_lib' + get_ext_suffix())
SERVER_LIB_CTYPES = ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL)
SERVER_LIB_CTYPES.byteps_server()
110 changes: 110 additions & 0 deletions byteps/server/queue.h
@@ -0,0 +1,110 @@
// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef BYTEPS_SERVER_QUEUE_H
#define BYTEPS_SERVER_QUEUE_H

#include <vector>
#include <mutex>
#include <condition_variable>
#include <memory>
#include <algorithm>
#include <cstdint>
#include <unordered_map>

namespace byteps {
namespace server {

/**
* \brief thread-safe queue allowing push and waited pop
*/
class PriorityQueue {
public:
PriorityQueue(bool is_schedule) {
enable_schedule_ = is_schedule;
if (enable_schedule_) {
std::make_heap(queue_.begin(), queue_.end(),
[this](const BytePSEngineMessage& a, const BytePSEngineMessage& b) {
return ComparePriority(a, b);
}
);
}
}
~PriorityQueue() { }

/**
* \brief push a value and sort it into the heap. Thread-safe.
* \param new_value the value
*/
void Push(BytePSEngineMessage new_value) {
mu_.lock();
queue_.push_back(std::move(new_value));
if (enable_schedule_) {
++push_cnt_[new_value.key];
std::push_heap(queue_.begin(), queue_.end(),
[this](const BytePSEngineMessage& a, const BytePSEngineMessage& b) {
return ComparePriority(a, b);
}
);
}
mu_.unlock();
cond_.notify_all();
}

/**
* \brief wait until an element can be popped from the front. Thread-safe.
* \param value the popped value
*/
void WaitAndPop(BytePSEngineMessage* value) {
std::unique_lock<std::mutex> lk(mu_);
cond_.wait(lk, [this]{return !queue_.empty();});
if (enable_schedule_) {
std::pop_heap(queue_.begin(), queue_.end(),
[this](const BytePSEngineMessage& a, const BytePSEngineMessage& b) {
return ComparePriority(a, b);
}
);
*value = queue_.back();
queue_.pop_back();
} else {
*value = std::move(queue_.front());
queue_.erase(queue_.begin());
}
}

void ClearCounter(uint64_t key) {
if (!enable_schedule_) return;
std::unique_lock<std::mutex> lk(mu_);
push_cnt_[key] = 0;
}

bool ComparePriority(const BytePSEngineMessage& a, const BytePSEngineMessage& b) {
if (push_cnt_[a.key] == push_cnt_[b.key]) {
return (a.id > b.id);
} else {
return (push_cnt_[a.key] > push_cnt_[b.key]);
}
}

private:
mutable std::mutex mu_;
std::vector<BytePSEngineMessage> queue_;
std::condition_variable cond_;
std::unordered_map<uint64_t, uint64_t> push_cnt_;
volatile bool enable_schedule_ = false;
};

} // namespace server
} // namespace byteps

#endif // BYTEPS_SERVER_QUEUE_H
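PriorityQueue expects a BytePSEngineMessage type to be visible before this header is included; the real struct is defined elsewhere in the server sources. Below is a hedged usage sketch with a minimal stand-in message carrying only the fields the comparator touches: with scheduling enabled, WaitAndPop returns the message whose key has the fewest recorded pushes, breaking ties by the smaller id.

```cpp
// Usage sketch for the PriorityQueue above. BytePSEngineMessage here is a
// hypothetical minimal stand-in; the real struct has more fields.
#include <cstdint>
#include <iostream>

struct BytePSEngineMessage {
  uint64_t key;
  uint64_t id;
};

#include "byteps/server/queue.h"  // include path assumed

int main() {
  // Scheduling enabled: pops are ordered by per-key push count, then by id.
  byteps::server::PriorityQueue q(/*is_schedule=*/true);

  q.Push({42, 1});
  q.Push({42, 2});
  q.Push({7, 3});  // key 7 has fewer pushes than key 42, so it pops first

  BytePSEngineMessage m;
  q.WaitAndPop(&m);
  std::cout << "popped key " << m.key << " id " << m.id << std::endl;  // key 7

  // Presumably called once the server finishes an aggregation round for a key.
  q.ClearCounter(42);
  return 0;
}
```

The per-key push counter is what lets the server favor keys with fewer outstanding pushes when scheduling is on; with scheduling off, the class degrades to a plain FIFO protected by the same mutex and condition variable.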