gradient compression support #225

Merged · 386 commits · Aug 13, 2020

Conversation

@jasperzhong (Contributor) commented Mar 19, 2020

Motivation

Currently, BytePS does not fully support gradient compression. The compression it does support lives in each framework plugin in Python. Such a design eases implementation, but it rules out more aggressive compression: NCCL supports only a limited set of reduction operations (Sum, Prod, etc.), and these operations are meaningless for compressed data that has been tightly bit-packed. For example, in signSGD, one of the most popular gradient compression methods thanks to its simplicity and effectiveness, each bit represents the sign of one element of the original tensor, so a reduction such as summation makes no sense, yet reduction is necessary when there are multiple GPUs per node.

Another observation is that, compared to inter-node communication, intra-node communication is not the bottleneck. Furthermore, compressing too aggressively at this early stage loses information and may hurt accuracy. So there is no need to apply radical compression before the gradients reach the BytePS core on the worker nodes.

Therefore, changes need to be made.

Design Overview

In light of the problems mentioned above, we propose two-level gradient compression:

  1. intra-node: This is just an alias for the current implementation, named after its communication pattern. Transform FP32 tensors into FP16 on each GPU, reduce them across GPUs via NCCL, and copy them to the CPU buffer to await the next-level compression. The purpose of this stage is to reduce the intra-node communication overhead introduced by multiple GPUs. Since intra-node communication is very fast, especially with NCCL, only mild compression methods are applied here, mostly type conversion. This stage is framework-specific and is implemented in each plugin.

  2. inter-node: Inter-node communication is usually the bottleneck, so more aggressive gradient compression algorithms are applied here. This stage is framework-agnostic and is implemented in the BytePS core.

It is worth mentioning that our design supports all frameworks.

(architecture diagram)

Interface

Only a few changes are needed on the user side: a few lines in the training script to specify which compression algorithm to use and the parameters it needs. Take MXNet as an example.

compression_params = {
    "compressor": opt.compressor,
    "ef": opt.ef,
    "momentum": opt.compress_momentum,
    "scaling": opt.onebit_scaling,
    "k": opt.k,
}

trainer = bps.DistributedTrainer(params, optimizer, optimizer_params, compression_params=compression_params)

Here we prescribe a set of keys. Users can look up the documentation to determine which keys to use. Some common keys are listed below.

| Key | Description |
| --- | --- |
| compressor | compression algorithm: onebit / dithering / topk / randomk |
| k | an integer; must be specified when using dithering / topk / randomk |
| scaling | optional; whether to enable scaling for onebit (default: false) |
| ef | error-feedback algorithm, e.g. vanilla |
| momentum | momentum algorithm, e.g. nesterov |
| seed | random seed |

If the user's input is invalid, BytePS prints a warning and aborts.

Implementation

Parameter Data Structure

To offer users a unified interface, we have to address the registration problem: parameters vary across different compression algorithms. For example, the topk and randomk algorithms need the parameter k to be specified, while the onebit algorithm may take a flag controlling whether to enable scaling. Some parameters are optional and others are not, so parameter passing is a challenge.

We address this challenge by using a string-to-string dictionary (std::unordered_map<std::string, std::string> in C++, dict in Python) as the unified data structure for passing parameters. As mentioned above, we prescribe specific strings as keys, so the dictionary looks like:

{"byteps_compressor_type": "topk", "byteps_compressor_k": "3", "byteps_error_feedback_type": "vanilla"}

Python

For MXNet users, the dictionary can be attached as attributes of the ParameterDict. We can then filter out these parameters by their "byteps_" prefix. For example,

for i, param in enumerate(self._params):
    byteps_declare_tensor("parameter_" + str(i))
    if param.grad_req != 'null':
        # pick up the compression-related attributes attached to this parameter
        byteps_params = dict(
            filter(lambda attr: attr[0].startswith("byteps_"),
                   param.__dict__.items())
        )
        byteps_declare_tensor("gradient_" + str(i), **byteps_params)

C++

Using ctypes, we can pass the dictionary conveniently. For example,

extern "C" void byteps_mxnet_declare_tensor(char* name, int num_params,
                                           char** param_keys,
                                           char** param_vals) {
 ...

 std::unordered_map<std::string, std::string> param_dict;
 std::string key, val;
 std::string::size_type pos;
 for (int i = 0; i < num_params; ++i) {
   key = param_keys[i];
   val = param_vals[i];
   param_dict[key] = val;
 }

 ...
}
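To give a feel for how this dictionary might be consumed once it reaches the C++ side, here is a minimal sketch (not the PR's actual code; the function name DispatchCompressor and the control flow are purely illustrative). It validates the prescribed keys and aborts on invalid input, as described in the Interface section:

#include <cstdlib>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical helper -- not the PR's actual registration code.
void DispatchCompressor(
    const std::unordered_map<std::string, std::string>& param_dict) {
  auto it = param_dict.find("byteps_compressor_type");
  if (it == param_dict.end()) return;  // no compression requested for this tensor

  const std::string& type = it->second;
  if (type == "topk" || type == "randomk" || type == "dithering") {
    // these compressors require an integer k
    auto kit = param_dict.find("byteps_compressor_k");
    if (kit == param_dict.end()) {
      std::cerr << "compressor " << type << " requires parameter k" << std::endl;
      std::abort();  // invalid input: warn and abort
    }
    int k = std::stoi(kit->second);  // all values arrive as strings
    // ... construct the corresponding compressor with k ...
    (void)k;
  } else if (type == "onebit") {
    // ... construct the onebit compressor (optional scaling flag handled similarly) ...
  } else {
    std::cerr << "unknown compressor: " << type << std::endl;
    std::abort();
  }
}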

Compressor - Development API

We want developers to be able to implement their own gradient compression algorithms without fully understanding how BytePS works; all they need to know is the development API. We currently implement some commonly used gradient compression algorithms, and we hope more novel algorithms will be built on this API in the future. We abstract a compression algorithm as a compressor. The Compressor class looks like this:

class Compressor {
 public:
  Compressor(size_t size, DataType dtype)
      : _size(size),
        _dtype(dtype),
        _buf(new byte_t[size]),
        _cpu_reducer(new CpuReducer(nullptr)){};
  virtual ~Compressor() = default;

  virtual tensor_t Compress(tensor_t grad) = 0;

  virtual tensor_t Decompress(tensor_t compressed) = 0;

  virtual void FastUpdateError(tensor_t error, tensor_t corrected,
                               tensor_t compressed) {
    BPS_LOG(FATAL) << "FastUpdateError is not implemented";
  };

  std::unique_ptr<byte_t[]> _buf;

  size_t _size;

  DataType _dtype;

  std::unique_ptr<CpuReducer> _cpu_reducer;
};
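For illustration, a concrete compressor subclass under this API might look like the sketch below. This is a deliberately naive sign compressor written only to show where Compress and Decompress fit; it is not the PR's onebit implementation (signs are stored one byte per element rather than bit-packed), and it assumes tensor_t exposes a raw data pointer (data) and a byte length (size), which may differ from the actual struct in the PR.

// Illustrative only -- not the PR's onebit compressor.
// Assumes tensor_t is roughly {byte_t* data; size_t size;}.
class NaiveSignCompressor : public Compressor {
 public:
  NaiveSignCompressor(size_t size, DataType dtype) : Compressor(size, dtype) {}

  tensor_t Compress(tensor_t grad) override {
    auto* src = reinterpret_cast<const float*>(grad.data);
    size_t len = grad.size / sizeof(float);
    auto* dst = reinterpret_cast<int8_t*>(_buf.get());
    for (size_t i = 0; i < len; ++i) dst[i] = (src[i] >= 0) ? 1 : -1;
    // hand back a view into the compressor-owned buffer (1/4 of the FP32 size)
    return tensor_t{_buf.get(), len * sizeof(int8_t)};
  }

  tensor_t Decompress(tensor_t compressed) override {
    // assumes compressed.data does not alias _buf (e.g. it was just received)
    auto* src = reinterpret_cast<const int8_t*>(compressed.data);
    size_t len = compressed.size / sizeof(int8_t);
    auto* dst = reinterpret_cast<float*>(_buf.get());
    for (size_t i = 0; i < len; ++i) dst[i] = static_cast<float>(src[i]);
    return tensor_t{_buf.get(), len * sizeof(float)};
  }
};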

To keep modifications to the BytePS core small, we want compressors to be as general as possible. Ideally, a base compressor pointer/reference can represent any kind of compressor and only needs to expose two operations: Compress and Decompress. This is challenging because gradient compression has optional features, such as error-feedback and momentum, two common methods to correct bias and to accelerate training, respectively. For example, with error-feedback, gradients are first corrected with the error (the information lost in the previous compression) before being compressed, and the error is then recomputed. The workflow therefore differs from vanilla gradient compression.

To support all these features while still exposing a unified API, we use the decorator pattern. We regard error-feedback as additional behavior layered on top of a compressor. A unified API means compressors with error-feedback must expose the same methods as those without; creating a subclass for every compressor would be far too redundant, and the decorator pattern solves exactly this problem. We create a decorator class named ErrorFeedback that inherits from Compressor while also holding a Compressor member. For example,

class ErrorFeedback : public Compressor {
 public:
  ErrorFeedback(size_t size, DataType dtype, std::unique_ptr<Compressor> cptr)
      : Compressor(size, dtype),
        _cptr(std::move(cptr)),
        _error(new byte_t[size]()) {}
  virtual ~ErrorFeedback() = default;

  virtual tensor_t Compress(tensor_t grad) final;

  virtual tensor_t Decompress(tensor_t compressed) final;

 protected:

  virtual void UpdateGradient(tensor_t grad) = 0;

  virtual void UpdateError(tensor_t corrected, tensor_t compressed);

 protected:
  std::unique_ptr<byte_t[]> _error;

 private:
  std::unique_ptr<Compressor> _cptr;
};

And the workflow is implemented in Compress and Decompress. For example,

tensor_t ErrorFeedback::Compress(tensor_t grad) {
  // 1. grad <- grad + error
  UpdateGradient(grad);

  // 2. c <- Compress(grad)
  auto compressed = _cptr->Compress(grad);

  // 3. e <- grad - Decompress(c)
  UpdateError(grad, compressed);

  return compressed;
}

tensor_t ErrorFeedback::Decompress(tensor_t compressed) {
  // directly forward to internal compressor
  return _cptr->Decompress(compressed);
}

Momentum is implemented in the same way. ErrorFeedback and Momentum are themselves base classes to inherit from. In this way, error-feedback and momentum become optional features that can be added on top of any vanilla gradient compression algorithm.
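To make the composition concrete, here is a hedged sketch of how such a decorator chain might be assembled. The class names OnebitCompressor, VanillaErrorFeedback, and NesterovMomentum and their constructor arguments are illustrative assumptions; the actual factory/registry code in the PR may differ. The point is only that each layer wraps the next and exposes the same Compress/Decompress interface to the BytePS core.

// Illustrative factory -- class names and constructor signatures are assumptions.
std::unique_ptr<Compressor> BuildCompressor(size_t size, DataType dtype) {
  // innermost layer: the vanilla compression algorithm
  std::unique_ptr<Compressor> inner(new OnebitCompressor(size, dtype));
  // optional: error-feedback wraps the vanilla compressor
  std::unique_ptr<Compressor> ef(
      new VanillaErrorFeedback(size, dtype, std::move(inner)));
  // optional: momentum wraps error-feedback (workers only; servers skip it)
  std::unique_ptr<Compressor> mom(
      new NesterovMomentum(size, dtype, std::move(ef)));
  return mom;
}

// The BytePS core then only sees the base pointer:
//   auto compressed = compressor->Compress(grad);
//   auto decompressed = compressor->Decompress(recv);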

Note that momentum is not applied on the servers.

Experiments

CIFAR100

End-to-End Training

We evaluate distributed training of ResNet18_v2 on the CIFAR100 dataset with 4 AWS P3.16xlarge instances, each equipped with 8 V100 GPUs and a 25Gbps network. The compression algorithms benchmarked here are combined with error-feedback and Nesterov momentum. We set k = 1 for topk and k = 8 for randomk, and train for 200 epochs.


| f888c8d | Val Acc | Time (s) |
| --- | --- | --- |
| baseline | 0.713799 | 703.1527987500002 |
| onebit | 0.705601 | 629.4210848750001 |
| randomk | 0.6991 | 501.99770550000005 |
| topk | 0.704202 | 507.90769437499966 |

The results show that compression can reduce up to 28.6% end-to-end training time without accuracy loss.

Slow Network

Gradient compression is more beneficial over slower networks. We therefore limit the network bandwidth to 100Mbps (both downlink and uplink) and keep all other settings unchanged. The results show up to a 6x reduction in training time.


| b382f99 | Time (s) |
| --- | --- |
| baseline | 518.321322125 |
| onebit | 195.236724875 |
| randomk | 89.672168625 |
| topk | 83.9287285 |

IMAGENET

To save time, we only tested the onebit algorithm; topk and randomk are not guaranteed to converge on ImageNet.

Workload Breakdown

In this experiment, we break the workload down into computation and communication. We use 8 Amazon EC2 p3.2xlarge instances, each equipped with one Nvidia V100 GPU and 10Gbps Ethernet. We train two CNN models: ResNet-50_v2 and VGG-16. We first measure the computation time by timing 50 iterations on one node (t0), then measure the total training time of 50 iterations on 8 nodes (t1), and estimate the communication time as t1 − t0.

As the figure shows, dist-EF-SGDM reduces communication to varying degrees. For ResNet-50_v2 the drop is modest (a 17.6% decrease), mainly because of its smaller model size. In contrast, dist-EF-SGDM yields a remarkable decline (a 73.2% decrease) for VGG-16, which has a much larger model size (528MB).

[ResNet50_v2]
(figure)

[VGG-16]
(figure)

Scaling Efficiency

We also measure scaling efficiency as the number of nodes varies from 1 to 8, following the same setup as in the experiment above. The figure shows that gradient compression improves scaling efficiency. The gain is much larger for VGG-16 than for ResNet-50_v2, since ResNet-50_v2 has a smaller communication overhead.

[ResNet50_v2]
(figure)

[VGG-16]
(figure)


The two sub-experiments above were conducted two months ago, and there have been large updates since then, so the results are somewhat outdated and are provided for reference only.

End-to-End Training

Finally, we train ResNet50_v2 and VGG-16 end-to-end to measure the total reduction in training time. For such large-batch training, warmup and a linearly scaled learning rate are used to avoid the generalization gap. We set the number of warmup epochs to 5 and use a cosine annealing schedule for learning-rate decay. For ResNet50_v2 we use 8 AWS EC2 P3.16xlarge instances; for VGG-16 we use 4.
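For reference, the warmup plus cosine-annealing schedule has the standard form below; this is the generic shape of such a schedule, not a verbatim extract from our training script. With warmup length $T_w$ (5 epochs here), total epochs $T$, and peak learning rate $\eta_{\max}$ after linear scaling:

$$
\eta(t) =
\begin{cases}
\eta_{\max}\,\dfrac{t}{T_w}, & t \le T_w,\\[6pt]
\dfrac{\eta_{\max}}{2}\left(1 + \cos\!\left(\pi\,\dfrac{t - T_w}{T - T_w}\right)\right), & t > T_w.
\end{cases}
$$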

[ResNet50_v2]
(figures)

As the figure shows, we reduce the training time by 8.0% without accuracy loss for ResNet50_v2.

| 6c44049 | Val Acc | Time (h) |
| --- | --- | --- |
| sgdm | 0.76914465625 | 2.6505945833029516 |
| dist-ef-sgdm | 0.7632242968749999 | 2.4378090010373263 |

[VGG-16]
(figures)

The figure above shows that our implementation of dist-EF-SGDM reduces the training time for 100 epochs by 39.04% compared to full-precision SGDM. We note that there is a small accuracy gap between dist-EF-SGDM and SGDM; we will investigate this in the future.

TODO

  • support inter-node compression
  • support intra-node for MXNet
  • support onebit compressor
  • support error-feedback
  • support momentum
  • support other compressors
  • support FP16
  • support PyTorch and Tensorflow

Precautions

  1. To run successfully, ps-lite needs a one-line change; see dmlc/ps-lite#168 (Relax Size Check).
  2. We only support Gluon for MXNet at the moment; the raw MXNet API is not supported.
  3. Gradient compression has its own overhead, so it is a trade-off. It is only suitable for some cases, e.g. slow networks or large models; in other cases it can even hurt performance.
  4. The momentum here plays the same role as the framework's momentum. Why do we have to implement momentum again? For some algorithms, such as dist-EF-SGDM, momentum must be added before the gradient is exchanged, but many frameworks, like MXNet, exchange the gradient first and then apply momentum. So we have to implement momentum inside BytePS. When the inside momentum is used, the outside momentum should be disabled (set \mu = 0) in the user's script.
  5. FP16 is not supported yet.

Acknowledgement

Thanks @eric-haibin-lin @szhengac for guidance! They have been giving many valuable suggestions!

@ymjiang (Member) commented Mar 20, 2020

Hi @vycezhong , thank you very much for your contribution and the detailed documentation! We will start to review the new features soon. Eventually, we want to make sure the original functionality does not break, and the compression benchmarks work as expected.

@jasperzhong (Contributor, Author) commented Jun 20, 2020

Hello everyone, we have summarized our updates over the past three months (Mar 19 – Jun 20): https://docs.google.com/presentation/d/1Dt1Sh2ixVF8Or_Q3lzUM81F4Thj5LT8Xw6QjU1e6iwQ/edit?usp=sharing

If you have time, I think you can start to review the code now. Thank you in advance!

@eric-haibin-lin (Collaborator) left a comment:
@vycezhong could you resolve the conflict?

@jasperzhong (Contributor, Author) replied:

> @vycezhong could you resolve the conflict?

done

@pleasantrabbit pleasantrabbit merged commit dae88d6 into bytedance:master Aug 13, 2020
@zhuangwang93 commented:
Hi,

It would be great for BytePS to support gradient compression, because tons of gradient compression algorithms (sparsification, quantization, and low-rank decomposition) have been proposed in the machine learning community. My observation is that these algorithms can indeed reduce communication time, but the compression/decompression overhead can offset the communication improvement. The good news is that there are solutions to this costly overhead.

I am working on a framework to support gradient compression algorithms and it can reduce the compression/decompression overhead to a negligible level (< 5ms). Based on our preliminary results, the scaling factor* with 8 GPUs is >90% for ResNet50, ResNet101, and VGG16 without NCCL.

If you are interested, we can have a talk about this framework design.

*scaling factor is defined as f=s_n/s_1, where s_n and s_1 are the training speeds with n GPUs and 1 GPU.

@bobzhuyb (Member) replied, quoting @zhuangwang93's comment above:

@eric-haibin-lin @vycezhong

@jasperzhong (Contributor, Author) replied, quoting @zhuangwang93's comment above:

Hello, thanks for your interest. Compression/decompression overhead is indeed an issue, and we have made many efforts to reduce it.

We are still a little confused about your work and we have a few questions.

  1. What is your framework based on, say, TensorFlow, Pytorch or MXNet? Or is it a plugin to BytePS or Horovod?
  2. What's the main language you use in your framework, like Python or C++?
  3. What are the settings behind your preliminary results? For example, does compression/decompression run on the GPU or the CPU?

@zhuangwang93 replied, quoting the questions above:

Thanks for your reply.

  1. Our framework is built on Horovod; it currently supports PyTorch and will support TensorFlow;
  2. The main language we use is Python. We also provide some CUDA extensions for PyTorch to speed up the compression/decompression operations, and we modify the Horovod source code (C++) to support a new communication primitive;
  3. The experiments run on P100 GPUs. Because there are still some bugs in the NCCL support, we only test scalability with MPI. The tested gradient compression algorithms include DGC (1%), QSGD, EFSignSGD, SignSGD, and OneBit.

@GeKeShi commented Apr 21, 2022

Hi,
Will the gradient compression component support PyTorch in the future?
Thank you!

@jasperzhong (Contributor, Author) replied:

> Hi, will the gradient compression component support PyTorch in the future? Thank you!

Sorry for the late reply. For PyTorch, you can refer to this repo.
