From 0518252d6431133f94e697df9ac46d1efb1a6a2f Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 11 Dec 2020 10:05:37 -0800 Subject: [PATCH 01/17] add manual workflow to run tests with precompiled ops --- .github/workflows/pre-compile-ops.yml | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 .github/workflows/pre-compile-ops.yml diff --git a/.github/workflows/pre-compile-ops.yml b/.github/workflows/pre-compile-ops.yml new file mode 100644 index 000000000000..4005d4baf2fc --- /dev/null +++ b/.github/workflows/pre-compile-ops.yml @@ -0,0 +1,47 @@ +# This is a basic workflow to help you get started with Actions + +name: Tests-w-precompiled-ops + +# Controls when the action will run. +on: + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# A workflow run is made up of one or more jobs that can run sequentially or in parallel +jobs: + # This workflow contains a single job called "build" + build: + # The type of runner that the job will run on + runs-on: self-hosted + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + + # Runs a single command using the runners shell + - name: environment + run: | + nvidia-smi + which python + python --version + which nvcc + nvcc --version + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + # Runs a set of commands using the runners shell + - name: Install deepspeed + run: | + DS_BUILD_OPS=1 pip install .[dev] + ds_report + + - name: Formatting checks + run: | + pre-commit run --all-files + + # Runs a set of commands using the runners shell + - name: Unit tests + run: | + if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/ From 8a184b6b1ddfe4bf41467c06b693817af80aa530 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 11 Dec 2020 10:15:33 -0800 Subject: [PATCH 02/17] [build] fix computer capability arch flags, add PTX, handle PTX (#591) * fix arch flags, add PTX * bug fix Co-authored-by: Jeff Rasley --- op_builder/builder.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index f44aee79637a..1f350065b4f6 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -221,7 +221,7 @@ def compute_capability_args(self, cross_compile_archs=None): 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. 2. If neither is set default compute capabilities will be used - 3. Under `jit_mode` compute capabilities of all visible cards will be used. + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX Format: @@ -243,6 +243,7 @@ def compute_capability_args(self, cross_compile_archs=None): if cc not in ccs: ccs.append(cc) ccs = sorted(ccs) + ccs[-1] += '+PTX' else: # Cross-compile mode, compile for various architectures # env override takes priority @@ -260,8 +261,10 @@ def compute_capability_args(self, cross_compile_archs=None): args = [] for cc in ccs: - cc = cc.replace('.', '') - args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}') + num = cc[0] + cc[2] + args.append(f'-gencode=arch=compute_{num},code=sm_{num}') + if cc.endswith('+PTX'): + args.append(f'-gencode=arch=compute_{num},code=compute_{num}') return args From 66268bd3377f80e7bf37e48184a6211854b64a9e Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 11 Dec 2020 12:40:14 -0800 Subject: [PATCH 03/17] add DeepSpeedZeroConfig repr method (#596) Co-authored-by: Jeff Rasley --- deepspeed/runtime/zero/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 14bfc937705c..b784f3ffdd6c 100755 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -6,6 +6,7 @@ from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.utils import logger from deepspeed.runtime.zero.constants import * +import json class DeepSpeedZeroConfig(object): @@ -54,6 +55,9 @@ def read_zero_config_deprecated(self, param_dict): def repr(self): return self.__dict__ + def __repr__(self): + return json.dumps(self.__dict__, sort_keys=True, indent=4) + def _initialize(self, zero_config_dict): self.stage = get_scalar_param(zero_config_dict, ZERO_OPTIMIZATION_STAGE, From a4763f5516c0a9bb8f32d2e8a48c618f0fe88e37 Mon Sep 17 00:00:00 2001 From: carefree0910 Date: Sat, 12 Dec 2020 05:52:06 +0800 Subject: [PATCH 04/17] Supported customizing kwargs for lr_scheduler (#584) Co-authored-by: Jeff Rasley --- deepspeed/runtime/engine.py | 8 ++++---- deepspeed/runtime/pipe/engine.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7431b2c892c4..76ba6af78b76 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -979,7 +979,7 @@ def clip_fp32_gradients(self): torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping()) - def _take_model_step(self): + def _take_model_step(self, lr_kwargs): if self.gradient_clipping() > 0.0: if not self.fp16_enabled() and not self.amp_enabled(): self.clip_fp32_gradients() @@ -1010,14 +1010,14 @@ def _take_model_step(self): self.skipped_steps += 1 else: if self.lr_scheduler is not None: - self.lr_scheduler.step() + self.lr_scheduler.step(**(lr_kwargs or {})) if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: self._report_progress(self.global_steps + 1) self.global_steps += 1 self.global_samples += self.train_batch_size() - def step(self): + def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. """ @@ -1034,7 +1034,7 @@ def step(self): if self.progressive_layer_drop: self.progressive_layer_drop.update_state(self.global_steps) - self._take_model_step() + self._take_model_step(lr_kwargs) self.tput_timer.stop(report_progress) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 954774e58912..5c5d896dfc0d 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -940,14 +940,14 @@ def _exec_recv_grads(self, buffer_id): if self.wall_clock_breakdown(): self.timers('pipe_recv_grad').stop() - def _exec_optimizer_step(self): + def _exec_optimizer_step(self, lr_kwargs=None): if self.wall_clock_breakdown(): self.timers('step_microstep').start() self.timers('step').start() self.mem_status('BEFORE STEP', reset_max=True) self._force_grad_boundary = True - self._take_model_step() + self._take_model_step(lr_kwargs) self._force_grad_boundary = False self.mem_status('AFTER STEP') From c5a449f9a370d99cdbd265e696b6b3b2a1bd2565 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 11 Dec 2020 14:54:45 -0800 Subject: [PATCH 05/17] Update launcher to set local rank environ variable (#597) * Update launch.py * formatting --- deepspeed/launcher/launch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index 205aee2d6ac4..74a0530c7f98 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -113,6 +113,7 @@ def main(): # each process's rank dist_rank = global_rank_mapping[local_node][local_rank] current_env["RANK"] = str(dist_rank) + current_env["LOCAL_RANK"] = str(local_rank) # spawn the processes cmd = [ From 9f8e8f38290ecea6b5054df49ec8635b248c18a8 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 14 Dec 2020 14:24:58 -0800 Subject: [PATCH 06/17] implement missing get_last_lr (#595) Co-authored-by: Jeff Rasley --- deepspeed/runtime/lr_schedules.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 5ec106c28d67..e7e3be1e786b 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -381,6 +381,12 @@ def get_lr(self): lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr ] + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + def _update_optimizer(self, group_lrs): for param_group, lr in zip(self.optimizer.param_groups, group_lrs): param_group['lr'] = lr @@ -390,6 +396,7 @@ def step(self, batch_iteration=None): batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration self._update_optimizer(self.get_lr()) + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -628,12 +635,19 @@ def get_lr(self): return self._get_cycle_lr() return self._get_decay_lr(self.last_batch_iteration - self.total_size) + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + def step(self, batch_iteration=None): if batch_iteration is None: batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -690,12 +704,19 @@ def get_lr(self): self.delta_lrs) ] + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + def step(self, last_batch_iteration=None): if last_batch_iteration is None: last_batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = last_batch_iteration for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} From 007466e576b65a5efd2c8195c4f1349b1e225c1b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 15 Dec 2020 13:44:32 -0800 Subject: [PATCH 07/17] [doc] xref to hostfile discussion (#604) * [doc] xref to hostfile discussion wasn't clear where to find what was meant by `hostfile` - so adding a link to where it's discussed. * remove whitespace --- docs/_pages/features.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/_pages/features.md b/docs/_pages/features.md index ec0724e11aa4..2074bb3e3b0f 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -28,7 +28,8 @@ deepspeed --hostfile= \ \ --deepspeed --deepspeed_config ds_config.json ``` -The script `` will execute on the resources specified in ``. +The script `` will execute on the resources specified in +[``](/getting-started/#resource-configuration-multi-node). ## Pipeline Parallelism DeepSpeed provides [pipeline parallelism](/tutorials/pipeline/) for memory- From 6380ee35116dfc7c2a037d48fbd790d8dbabfc1d Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 15 Dec 2020 15:29:21 -0800 Subject: [PATCH 08/17] Fixes for RTD build errors (#606) Co-authored-by: Shaden Smith --- deepspeed/git_version_info.py | 8 ++++++-- deepspeed/ops/sparse_attention/softmax.py | 4 ++-- docs/code-docs/source/conf.py | 2 +- requirements/requirements-readthedocs.txt | 1 - 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py index d17948ae41a7..f04982c74f0d 100644 --- a/deepspeed/git_version_info.py +++ b/deepspeed/git_version_info.py @@ -2,8 +2,12 @@ # This is populated by setup.py from .git_version_info_installed import * except ModuleNotFoundError: - # Will be missing from checkouts that haven't been installed (e.g., readthedocs) - version = open('version.txt', 'r').read().strip() + import os + if os.path.isfile('version.txt'): + # Will be missing from checkouts that haven't been installed (e.g., readthedocs) + version = open('version.txt', 'r').read().strip() + else: + version = "0.0.0" git_hash = '[none]' git_branch = '[none]' diff --git a/deepspeed/ops/sparse_attention/softmax.py b/deepspeed/ops/sparse_attention/softmax.py index cd18fbcae71f..a0805ada4bc0 100644 --- a/deepspeed/ops/sparse_attention/softmax.py +++ b/deepspeed/ops/sparse_attention/softmax.py @@ -224,8 +224,8 @@ class Softmax: For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509 """ - - sparse_softmax = _sparse_softmax.apply + def sparse_softmax(*args, **kwargs): + return _sparse_softmax.apply(*args, **kwargs) def make_lut(self, device): """Generates the sparsity layout used in block-sparse softmax diff --git a/docs/code-docs/source/conf.py b/docs/code-docs/source/conf.py index 167f6427d7b4..eb9a412d8a4a 100644 --- a/docs/code-docs/source/conf.py +++ b/docs/code-docs/source/conf.py @@ -79,4 +79,4 @@ autoclass_content = 'both' -autodoc_mock_imports = ["torch", "apex", "mpi4py", "tensorboardX", "numpy"] +autodoc_mock_imports = ["torch", "apex", "mpi4py", "tensorboardX", "numpy", "cupy"] diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index c032a8c9fdad..78620c472c9d 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -1,2 +1 @@ tqdm -psutil From fd2f970bdfd3eada289b4e19a3adcf2c352a4d8f Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Thu, 17 Dec 2020 10:13:54 -0800 Subject: [PATCH 09/17] Transformer-kernel - supporting any arbitrary sequence-length (#587) Co-authored-by: Jeff Rasley --- DeepSpeedExamples | 2 +- csrc/transformer/ds_transformer_cuda.cpp | 13 ++++++--- csrc/transformer/softmax_kernels.cu | 20 +++++++++----- deepspeed/ops/transformer/transformer.py | 28 ++++++++++++++----- tests/unit/test_cuda_backward.py | 32 +++++++++------------- tests/unit/test_cuda_forward.py | 34 +++++++++--------------- 6 files changed, 72 insertions(+), 57 deletions(-) mode change 100644 => 100755 csrc/transformer/ds_transformer_cuda.cpp diff --git a/DeepSpeedExamples b/DeepSpeedExamples index fa1d1a71c486..abb270641ca8 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit fa1d1a71c48623db8a091d9cf636a5fe3b8f43c7 +Subproject commit abb270641ca8c33476282bde29916c395a060ae9 diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp old mode 100644 new mode 100755 index 85ec0418971c..ebd534d04ab3 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -14,6 +14,8 @@ static std::unordered_map> s_transformer_layers; +const int init_seq_length = 128; + // C++ interface template @@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id, int hidden_dim, int num_heads, int intermediate_size, - int seq_length, float attn_dropout_ratio, float hidden_dropout_ratio, int seed, @@ -604,14 +605,14 @@ int create_transformer_layer(int layer_id, { Context::Instance().SetSeed(seed); Context::Instance().TestGemmFP16( - test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads); + test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads); auto layer = std::make_shared>(layer_id, batch_size, hidden_dim, num_heads, intermediate_size, - seq_length, + init_seq_length, attn_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm, @@ -873,6 +874,12 @@ std::vector ds_transformer_backward(int layer_id, std::shared_ptr> layer = std::static_pointer_cast>(s_transformer_layers[layer_id]); + int seq_len = layer->GetSeqLength(); + if (g_output.size(1) != seq_len) { + seq_len = g_output.size(1); + layer->SetSeqLength(seq_len, bsz); + } + auto grad_input = torch::empty_like(input); auto grad_attn_qkvw = torch::empty_like(attn_qkvw); auto grad_attn_qkvb = torch::empty_like(attn_qkvb); diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu index 582da4829f47..be776b0c074d 100644 --- a/csrc/transformer/softmax_kernels.cu +++ b/csrc/transformer/softmax_kernels.cu @@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { auto temp = g.shfl_xor(max_val, i); @@ -113,7 +114,8 @@ __global__ void attn_softmax(float* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } @@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { auto temp = g.shfl_xor(max_val, i); @@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } @@ -339,7 +343,9 @@ void launch_attn_softmax(float* vals, dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / subblock_max_workload * threads) : threads); - + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); if (sequence_length <= 512) attn_softmax<32, (threads / 128), 128><<>>( vals, attn_mask, heads, seq_length4, iterations); @@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals, dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / subblock_max_workload * threads) : threads); - + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); if (sequence_length <= 512) attn_softmax<32, (threads / 128), 128><<>>( vals, attn_mask, heads, seq_length4, iterations); diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index a91e5ce6f08b..ea4b98848d3c 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -18,7 +18,6 @@ class TransformerConfig(): def __init__(self, batch_size, - max_seq_length, hidden_size, intermediate_size, heads, @@ -30,7 +29,6 @@ def __init__(self, self.batch_size = batch_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.max_seq_length = max_seq_length self.heads = heads self.attn_dropout_ratio = attn_dropout_ratio self.hidden_dropout_ratio = hidden_dropout_ratio @@ -92,7 +90,6 @@ class DeepSpeedTransformerConfig(TransformerConfig): """ def __init__(self, batch_size=-1, - max_seq_length=-1, hidden_size=-1, intermediate_size=-1, heads=-1, @@ -112,7 +109,6 @@ def __init__(self, super(DeepSpeedTransformerConfig, self).__init__( batch_size, - max_seq_length, hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, @@ -142,7 +138,7 @@ def from_dict(cls, json_object): @classmethod def from_json_file(cls, json_file): - with open(json_file, "r", encoding='utf-8') as reader: + with open(json_file, "r", encoding='utf-16') as reader: text = reader.read() return cls.from_dict(json.loads(text)) @@ -177,6 +173,18 @@ def forward(ctx, cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32 + inp_size = input.size() + if inp_size[1] % 16 != 0: + input = torch.cat((input, + torch.randn((inp_size[0], + (16 - (inp_size[1] % 16)), + inp_size[2]), + device=input.device, + dtype=input.dtype)), + 1) + input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \ + (16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3) + (output, inp_norm, qkv_tf, @@ -303,11 +311,17 @@ def forward(ctx, ctx.attn_layer_norm_var = attn_layer_norm_var ctx.layer_norm_var = layer_norm_var + if inp_size[1] % 16 != 0: + output = torch.narrow(output, 1, 0, inp_size[1]) return output @staticmethod def backward(ctx, grad_output): bsz = grad_output.shape[0] + grad_output_shape = grad_output.size() + if grad_output_shape[1] % 16 != 0: + grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \ + grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1) if bsz > ctx.config.batch_size: raise ValueError('grad_output batch size exceeds the limit.') @@ -398,6 +412,9 @@ def backward(ctx, grad_output): norm_w, norm_b) + if grad_output_shape[1] % 16 != 0: + grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1]) + return (grad_input, None, None, @@ -501,7 +518,6 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None): self.config.hidden_size, self.config.heads, self.config.intermediate_size, - self.config.max_seq_length, self.config.attn_dropout_ratio, self.config.hidden_dropout_ratio, self.config.seed, diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index 317cd7aa33c0..fd3f9887ad42 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -150,7 +150,7 @@ def create_models(ds_config): hidden_act="gelu", hidden_dropout_prob=ds_config.hidden_dropout_ratio, attention_probs_dropout_prob=ds_config.attn_dropout_ratio, - max_position_embeddings=ds_config.max_seq_length, + max_position_embeddings=512, type_vocab_size=2, initializer_range=ds_config.initializer_range) @@ -210,25 +210,18 @@ def set_seed(seed): torch.manual_seed(seed) -def run_backward(ds_config, atol=1e-2, verbose=False): +def run_backward(ds_config, seq_len, atol=1e-2, verbose=False): set_seed(123) bert_encoder, ds_encoder = create_models(ds_config) # prepare test data kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32 hidden_states = torch.randn(ds_config.batch_size, - ds_config.max_seq_length, + seq_len, ds_config.hidden_size, **kwargs) - input_mask = torch.randn(ds_config.batch_size, - 1, - 1, - ds_config.max_seq_length, - **kwargs) - Y = torch.randn(ds_config.batch_size, - ds_config.max_seq_length, - ds_config.hidden_size, - **kwargs) + input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs) + Y = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs) # run baseline base_results = bert_encoder(hidden_states, @@ -257,12 +250,12 @@ def run_backward(ds_config, atol=1e-2, verbose=False): #test_backward[3-1024-120-16-24-True-True-0.05] @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', [ - (3,1024,120,16,24,True,False, 0.05), - (3,1024,120,16,24,True,True, 0.05), - (3,1024,56,16,24,False,False, 0.1), - (3,1024,56,16,24,False,True, 0.2), - (3,128,56,2,24,False,False, 0.1), - (3,128,56,2,24,False,True, 0.2), + (3,1024,119,16,24,True,False, 0.05), + (3,1024,115,16,24,True,True, 0.05), + (1024,128,10,2,2,False,False, 0.1), + (3,1024,52,16,24,False,True, 0.2), + (3,128,51,2,24,False,False, 0.1), + (3,128,54,2,24,False,True, 0.2), ]) # yapf: disable def test_backward(batch_size, hidden_size, @@ -282,7 +275,6 @@ def test_backward(batch_size, ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size ds_config.intermediate_size = hidden_size - ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0 @@ -291,7 +283,7 @@ def test_backward(batch_size, ds_config.initializer_range = 0.02 ds_config.fp16 = use_fp16 - run_backward(ds_config, atol=atol) + run_backward(ds_config, seq_len, atol=atol) #@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 893b66c904bb..88cb90848603 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -117,7 +117,7 @@ def create_models(ds_config): hidden_act="gelu", hidden_dropout_prob=ds_config.hidden_dropout_ratio, attention_probs_dropout_prob=ds_config.attn_dropout_ratio, - max_position_embeddings=ds_config.max_seq_length, + max_position_embeddings=512, type_vocab_size=2, initializer_range=ds_config.initializer_range, fp16=ds_config.fp16) @@ -186,13 +186,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): # prepare test data kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32 - hidden_states = torch.randn(bsz, - seq_len, #ds_config.max_seq_length, - ds_config.hidden_size, - **kwargs) - input_mask = torch.randn(bsz, 1, 1, - seq_len, #ds_config.max_seq_length, - **kwargs) + hidden_states = torch.randn(bsz, seq_len, ds_config.hidden_size, **kwargs) + input_mask = torch.randn(bsz, 1, 1, seq_len, **kwargs) # run baseline base_results = bert_encoder(hidden_states, @@ -213,25 +208,25 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): # FP16 test cases can only run on the devices support FP16. @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', [ - (8,256,128,4,3,True,False), - (8,256,128,4,3,True,True), - (64,1024,128,16,3,True,False), - (64,1024,128,16,3,True,True), - (8,1024,384,16,3,True,False), + (8,256,53,4,3,True,False), + (8,256,52,4,3,True,True), + (3,1024,51,16,3,True,False), + (3,1024,54,16,3,True,True), + (8,1024,381,16,3,True,False), (8,1024,384,16,3,True,True), (8,1024,384,16,3,True,True), - (8,1024,120,16,3,True,False), + (8,1024,119,16,3,True,False), (8,1024,120,16,3,True,True), - (8,1024,512,16,3,True,False), + (8,1024,509,16,3,True,False), (8,1024,512,16,3,True,True), (64,1024,56,16,3,False,False), - (64,1024,56,16,3,False,True), + (64,1024,53,16,3,False,True), (64,1024,24,16,3,False,False), - (64,1024,24,16,3,False,True), + (64,1024,21,16,3,False,True), (8,1024,384,16,3,False,False), (8,1024,384,16,3,False,True), (8,1024,512,16,3,False,False), - (8,1024,512,16,3,False,True), + (8,1024,511,16,3,False,True), (8,1536,128,24,3,False,False), (8,1536,128,24,3,False,True), (8,2048,128,32,3,False,False), @@ -259,7 +254,6 @@ def test_forward(batch_size, ds_config.layer_id = None ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size - ds_config.max_seq_length = 128 #seq_len ds_config.intermediate_size = 4 * hidden_size ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 @@ -297,7 +291,6 @@ def test_forward_with_small_bsz(batch_size, ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size ds_config.intermediate_size = 4 * hidden_size - ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0 @@ -332,7 +325,6 @@ def test_forward_stochastic(batch_size, ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size ds_config.intermediate_size = 4 * hidden_size - ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0 From 7435b2f10af773b0204e77c3549b2b7df9a7a65b Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 17 Dec 2020 23:17:19 -0800 Subject: [PATCH 10/17] Ability to initialize distributed backend outside deepspeed runtime (#608) --- DeepSpeedExamples | 2 +- deepspeed/__init__.py | 1 + deepspeed/constants.py | 8 ++ deepspeed/launcher/constants.py | 5 -- deepspeed/launcher/launch.py | 2 +- deepspeed/launcher/runner.py | 4 +- deepspeed/runtime/constants.py | 5 -- deepspeed/runtime/engine.py | 108 ++---------------------- deepspeed/utils/__init__.py | 3 +- deepspeed/utils/distributed.py | 129 +++++++++++++++++++++++++++++ docs/_tutorials/getting-started.md | 30 +++---- install.sh | 2 +- tests/unit/common.py | 12 ++- 13 files changed, 175 insertions(+), 136 deletions(-) create mode 100644 deepspeed/constants.py create mode 100644 deepspeed/utils/distributed.py diff --git a/DeepSpeedExamples b/DeepSpeedExamples index abb270641ca8..78d69cb2f89a 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit abb270641ca8c33476282bde29916c395a060ae9 +Subproject commit 78d69cb2f89a27b1e9b072df8c3e47d00c024fdc diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 8ac0aad05562..ba6f9b5bb6bf 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -14,6 +14,7 @@ from .runtime.activation_checkpointing import checkpointing from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .utils import log_dist +from .utils.distributed import init_distributed from .pipe import PipelineModule diff --git a/deepspeed/constants.py b/deepspeed/constants.py new file mode 100644 index 000000000000..467e85aefcb6 --- /dev/null +++ b/deepspeed/constants.py @@ -0,0 +1,8 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +############################################# +# Torch distributed constants +############################################# +TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py index f384d58b2c52..fd56facc4343 100644 --- a/deepspeed/launcher/constants.py +++ b/deepspeed/launcher/constants.py @@ -1,10 +1,5 @@ # Copyright 2020 The Microsoft DeepSpeed Team -############################################# -# Torch distributed constants -############################################# -TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 - PDSH_LAUNCHER = 'pdsh' PDSH_MAX_FAN_OUT = 1024 diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index 74a0530c7f98..0958295efe06 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -16,7 +16,7 @@ from collections import defaultdict from argparse import ArgumentParser, REMAINDER -from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..utils import logger diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 9479bb63758c..eb03502cc3f2 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -19,8 +19,8 @@ import torch.cuda from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner -from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT, \ - PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER +from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..utils import logger DLTS_HOSTFILE = "/job/hostfile" diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index a731865714fe..c56c3898f60f 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -73,11 +73,6 @@ ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer" ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False -############################################# -# Torch distributed constants -############################################# -TORCH_DISTRIBUTED_DEFAULT_PORT = "29500" - # Steps STEPS_PER_PRINT = "steps_per_print" STEPS_PER_PRINT_DEFAULT = 10 diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 76ba6af78b76..49e1bedd3cfc 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -24,12 +24,12 @@ from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA + PLD_THETA, PLD_GAMMA from deepspeed.runtime.zero.constants import \ ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS from deepspeed.runtime.csr_tensor import CSRTensor import deepspeed.runtime.lr_schedules as lr_schedules -from deepspeed.utils import logger, log_dist +from deepspeed.utils import logger, log_dist, init_distributed from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop @@ -130,29 +130,14 @@ def __init__(self, if dist_init_required is False: assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()" - # DeepSpeed will initialize torch distributed only if the user has not already intialized it. - if dist_init_required and not dist.is_initialized(): - # discover using mpi4py if user specifies the flag - if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: - # if in Azure ML environment and user specified this flag, notify the user to remove the flag. - if self._in_aml(): - logger.warning( - "Please remove the --deepspeed_mpi flag if running on AzureML.") - self._mpi_check(args, dist_init_required) - else: - # detect if we are in Azure ML environment - if self._in_aml(): - self._set_environment_variables_for_nccl_backend(args) - - logger.info("Initializing torch distributed with backend: {}".format( - self.dist_backend)) - dist.init_process_group(backend=self.dist_backend) + # Initialize torch distributed if needed + init_distributed(dist_backend=self.dist_backend) self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() - self._init_distributed(dist_init_required) + self._set_distributed_vars() if self.tensorboard_enabled() and self.global_rank == 0: self.summary_writer = self.get_summary_writer() @@ -209,87 +194,6 @@ def __init__(self, self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten - def _in_aml(self): - # read AzureML environment variable to detect if we are using an Azure ML environment - if 'AZUREML_EXPERIMENT_ID' in os.environ: - return True - else: - return False - - def _set_environment_variables_for_nccl_backend(self, - args, - master_port=6105, - verbose=True): - """Helper routine to get and set environment variables. - This is adapted from Azure ML's documentation available from: - https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi - """ - os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] - os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] - single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int( - os.environ["WORLD_SIZE"]) - if not single_node: - master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":") - os.environ["MASTER_ADDR"] = master_node_params[0] - # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE - if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"] = str(master_port) - else: - os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"] - os.environ["MASTER_PORT"] = "54965" - print("NCCL_SOCKET_IFNAME original value = {}".format( - os.environ["NCCL_SOCKET_IFNAME"])) - - os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" - args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - - if verbose: - logger.info( - "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" - .format(os.environ['RANK'], - args.local_rank, - os.environ['WORLD_SIZE'], - os.environ['MASTER_ADDR'], - os.environ['MASTER_PORT'])) - - def _mpi_check(self, args, dist_init_required): - from mpi4py import MPI - import subprocess - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() - - master_addr = None - if rank == 0: - hostname_cmd = ["hostname -I"] - result = subprocess.check_output(hostname_cmd, shell=True) - master_addr = result.decode('utf-8').split()[0] - master_addr = comm.bcast(master_addr, root=0) - - # Determine local rank by assuming hostnames are unique - proc_name = MPI.Get_processor_name() - all_procs = comm.allgather(proc_name) - local_rank = sum([i == proc_name for i in all_procs[:rank]]) - - os.environ['RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - args.local_rank = local_rank - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT - - logger.info( - "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" - .format(os.environ['RANK'], - args.local_rank, - os.environ['WORLD_SIZE'], - os.environ['MASTER_ADDR'], - os.environ['MASTER_PORT'])) - - if not dist_init_required and dist.is_initialized(): - assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) - assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( - world_size, dist.get_world_size()) - def pld_enabled(self): return self._config.pld_enabled @@ -497,7 +401,7 @@ def _scheduler_from_config(self, optimizer): else: return None - def _init_distributed(self, dist_init_required): + def _set_distributed_vars(self): if self.local_rank >= 0: torch.cuda.set_device(self.local_rank) self.device = torch.device("cuda", self.local_rank) diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 37517764b375..c231edca4919 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -1,2 +1,3 @@ -from deepspeed.utils.logging import logger, log_dist +from .logging import logger, log_dist +from .distributed import init_distributed from deepspeed.runtime.dataloader import RepeatingLoader diff --git a/deepspeed/utils/distributed.py b/deepspeed/utils/distributed.py new file mode 100644 index 000000000000..e70f00b440bb --- /dev/null +++ b/deepspeed/utils/distributed.py @@ -0,0 +1,129 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' +import os +import torch + +from .logging import logger +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT + + +def init_distributed(dist_backend="nccl", + auto_mpi_discovery=True, + distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, + verbose=True): + """ + Initialize torch.distributed backend, potentially performing MPI discovery if needed + Arguments: + dist_backend (str): torch distributed backend, e.g., nccl, mpi, gloo + auto_mpi_discovery (bool): if distributed environment variables are not set, attempt to discover them from MPI + distributed_port (int, optional): torch distributed backend port + verbose (bool, optional): verbose logging + """ + + required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): + if verbose: + logger.info( + "Not using the DeepSpeed or torch.distributed launchers, attempting to detect MPI environment..." + ) + if in_aml() and not in_dlts(): + patch_aml_env_for_torch_nccl_backend(verbose=verbose) + else: + mpi_discovery(distributed_port=distributed_port, verbose=verbose) + + if not torch.distributed.is_initialized(): + if verbose: + logger.info( + "Initializing torch distributed with backend: {}".format(dist_backend)) + torch.distributed.init_process_group(backend=dist_backend) + + +def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True): + """ + Discovery MPI environment via mpi4py and map to relevant torch.distributed state + """ + from mpi4py import MPI + import subprocess + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + master_addr = None + if rank == 0: + hostname_cmd = ["hostname -I"] + result = subprocess.check_output(hostname_cmd, shell=True) + master_addr = result.decode('utf-8').split()[0] + master_addr = comm.bcast(master_addr, root=0) + + # Determine local rank by assuming hostnames are unique + proc_name = MPI.Get_processor_name() + all_procs = comm.allgather(proc_name) + local_rank = sum([i == proc_name for i in all_procs[:rank]]) + + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(distributed_port) + + if verbose: + logger.info( + "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" + .format(os.environ['RANK'], + os.environ['LOCAL_RANK'], + os.environ['WORLD_SIZE'], + os.environ['MASTER_ADDR'], + os.environ['MASTER_PORT'])) + + if torch.distributed.is_initialized(): + assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) + assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( + world_size, dist.get_world_size()) + + +def in_aml(): + # Are we running inside an Azure Machine Learning (AML) environment? + return 'AZUREML_EXPERIMENT_ID' in os.environ + + +def in_dlts(): + # Are we running on a DLTS cluster? + return 'DLTS_JOB_ID' in os.environ + + +def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True): + """Helper routine to get and set environment variables. + This is adapted from Azure ML's documentation available from: + https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi + """ + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int( + os.environ["WORLD_SIZE"]) + + if not single_node: + master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":") + os.environ["MASTER_ADDR"] = master_node_params[0] + # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(master_port) + else: + os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"] + os.environ["MASTER_PORT"] = "54965" + + if verbose: + logger.info("NCCL_SOCKET_IFNAME original value = {}".format( + os.environ["NCCL_SOCKET_IFNAME"])) + + os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" + os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] + + if verbose: + logger.info( + "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" + .format(os.environ['RANK'], + os.environ['LOCAL_RANK'], + os.environ['WORLD_SIZE'], + os.environ['MASTER_ADDR'], + os.environ['MASTER_PORT'])) diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index 1f23c64d4085..21268802d6c8 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -216,25 +216,27 @@ DeepSpeed will then make sure that these environment variables are set when launching each process on every node across their training job. -### MPI Compatibility +### MPI and AzureML Compatibility As described above, DeepSpeed provides its own parallel launcher to help launch multi-node/multi-gpu training jobs. If you prefer to launch your training job using MPI (e.g., mpirun), we provide support for this. It should be noted that DeepSpeed will still use the torch distributed NCCL backend and *not* the MPI -backend. To launch your training job with mpirun + DeepSpeed you simply pass us -an additional flag `--deepspeed_mpi`. DeepSpeed will then use -[mpi4py](https://pypi.org/project/mpi4py/) to discover the MPI environment (e.g., -rank, world size) and properly initialize torch distributed for training. In this -case you will explicitly invoke `python` to launch your model script instead of using -the `deepspeed` launcher, here is an example: -```bash -mpirun python \ - \ - --deepspeed_mpi --deepspeed --deepspeed_config ds_config.json -``` +backend. + +To launch your training job with mpirun + DeepSpeed or with AzureML (which uses +mpirun as a launcher backend) you simply need to install the +[mpi4py](https://pypi.org/project/mpi4py/) python package. DeepSpeed will use +this to discover the MPI environment and pass the necessary state (e.g., world +size, rank) to the torch distributed backend. -If you want to use this feature of DeepSpeed, please ensure that mpi4py is -installed via `pip install mpi4py`. +If you are using model parallelism, pipeline parallelism, or otherwise require +torch.distributed calls before calling `deepspeed.initialize(..)` we provide +the same MPI support with an additional DeepSpeed API call. Replace your initial +`torch.distributed.init_process_group(..)` call with: + +```python +deepspeed.init_distributed() +``` ## Resource Configuration (single-node) In the case that we are only running on a single node (with one or more GPUs) diff --git a/install.sh b/install.sh index b027d319cdd6..b9f1501d9cad 100755 --- a/install.sh +++ b/install.sh @@ -171,5 +171,5 @@ else pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/ pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl" pdsh -w $hosts "ds_report" - pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rmdir $tmp_wheel_path; fi" + pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rm $tmp_wheel_path/*.txt; rmdir $tmp_wheel_path; fi" fi diff --git a/tests/unit/common.py b/tests/unit/common.py index 73d7957e29f9..62b7495a025c 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -5,6 +5,8 @@ import torch.distributed as dist from torch.multiprocessing import Process +import deepspeed + import pytest # Worker timeout *after* the first worker has completed. @@ -33,10 +35,12 @@ def dist_init(local_rank, num_procs, *func_args, **func_kwargs): """Initialize torch.distributed and execute the user function. """ os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29503' - dist.init_process_group(backend=backend, - init_method='env://', - rank=local_rank, - world_size=num_procs) + os.environ['LOCAL_RANK'] = str(local_rank) + # NOTE: unit tests don't support multi-node so local_rank == global rank + os.environ['RANK'] = str(local_rank) + os.environ['WORLD_SIZE'] = str(num_procs) + + deepspeed.init_distributed(dist_backend=backend) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) From 81aeea361da3936b875a678b9cb44596800510b5 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 22 Dec 2020 22:26:26 -0800 Subject: [PATCH 11/17] Elastic training support (#602) Co-authored-by: Samyam Rajbhandari --- .github/workflows/main.yml | 10 +- bin/ds_elastic | 39 ++++ deepspeed/elasticity/__init__.py | 1 + deepspeed/elasticity/config.py | 80 +++++++ deepspeed/elasticity/constants.py | 74 +++++++ deepspeed/elasticity/elasticity.py | 334 +++++++++++++++++++++++++++++ deepspeed/runtime/config.py | 75 ++++++- deepspeed/runtime/config_utils.py | 4 + deepspeed/runtime/engine.py | 31 ++- deepspeed/runtime/pipe/engine.py | 2 + op_builder/builder.py | 4 +- requirements/requirements.txt | 1 + setup.py | 3 +- tests/unit/test_checkpointing.py | 4 +- tests/unit/test_elastic.py | 241 +++++++++++++++++++++ version.txt | 2 +- 16 files changed, 883 insertions(+), 22 deletions(-) create mode 100644 bin/ds_elastic create mode 100644 deepspeed/elasticity/__init__.py create mode 100644 deepspeed/elasticity/config.py create mode 100644 deepspeed/elasticity/constants.py create mode 100644 deepspeed/elasticity/elasticity.py create mode 100644 tests/unit/test_elastic.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 794adeb7ab00..173a51cda5de 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -4,14 +4,12 @@ name: Build # Controls when the action will run. on: - # Triggers the workflow on push or pull request events but only for the master branch push: - branches: [ master ] + paths-ignore: + - 'docs/**' pull_request: - branches: [ master ] - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: + paths-ignore: + - 'docs/**' # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: diff --git a/bin/ds_elastic b/bin/ds_elastic new file mode 100644 index 000000000000..ef92cbdab32d --- /dev/null +++ b/bin/ds_elastic @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +import argparse +import json + +import deepspeed +from deepspeed.elasticity import compute_elastic_config + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json") + parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size") + args = parser.parse_args() + ds_config = json.load(open(args.config, 'r')) + + ds_version = deepspeed.__version__ + + elastic_config = ds_config['elasticity'] + print('------------------------------------------') + print("Elasticity config:") + print('------------------------------------------') + print(json.dumps(elastic_config, indent=4, sort_keys=True)) + + if args.world_size > 0: + final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size) + print('------------------------------------------') + print(f"Calculated results for world size {args.world_size}:") + print('------------------------------------------') + print(f'final_batch_size .... {final_batch_size}') + print(f'valid_gpus .......... {valid_gpus}') + print(f'micro_batch_size .... {micro_batch_size}') + else: + final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version) + print('------------------------------------------') + print("Calculated results:") + print('------------------------------------------') + print(f'final_batch_size .... {final_batch_size}') + print(f'valid_gpus .......... {valid_gpus}') diff --git a/deepspeed/elasticity/__init__.py b/deepspeed/elasticity/__init__.py new file mode 100644 index 000000000000..be517de7df93 --- /dev/null +++ b/deepspeed/elasticity/__init__.py @@ -0,0 +1 @@ +from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config diff --git a/deepspeed/elasticity/config.py b/deepspeed/elasticity/config.py new file mode 100644 index 000000000000..dda56d72882c --- /dev/null +++ b/deepspeed/elasticity/config.py @@ -0,0 +1,80 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" + +import json +from .constants import * + + +class ElasticityError(Exception): + """ + Base exception for all elasticity related errors + """ + pass + + +class ElasticityConfigError(ElasticityError): + """ + Elasticity configuration error + """ + pass + + +class ElasticityIncompatibleWorldSize(ElasticityError): + """ + Attempting to run a world size that is incompatible with a given elastic config + """ + pass + + +class ElasticityConfig: + """ + Elastic config object, constructed from a param dictionary that only contains elastic + config parameters, example below: + + If elasticity is enabled, user must specify (at least) max_train_batch_size + and micro_batch_sizes. + + { + "enabled": true, + "max_train_batch_size": 2000, + "micro_batch_sizes": [2,4,6], + "min_gpus": 1, + "max_gpus" : 10000 + "min_time": 20 + "ignore_non_elastic_batch_info": false + "version": 0.1 + } + """ + def __init__(self, param_dict): + self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT) + if self.enabled: + if MAX_ACCEPTABLE_BATCH_SIZE in param_dict: + self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE] + else: + raise ElasticityConfigError( + f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}") + if MICRO_BATCHES in param_dict: + self.micro_batches = param_dict[MICRO_BATCHES] + else: + raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}") + else: + self.max_acceptable_batch_size = param_dict.get( + MAX_ACCEPTABLE_BATCH_SIZE, + MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT) + self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT) + self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT) + self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT) + self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT) + self.version = param_dict.get(VERSION, VERSION_DEFAULT) + self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, + PREFER_LARGER_BATCH_DEFAULT) + self.ignore_non_elastic_batch_info = param_dict.get( + IGNORE_NON_ELASTIC_BATCH_INFO, + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) + + def repr(self): + return self.__dict__ + + def __repr__(self): + return json.dumps(self.__dict__, sort_keys=True, indent=4) diff --git a/deepspeed/elasticity/constants.py b/deepspeed/elasticity/constants.py new file mode 100644 index 000000000000..7db563a83de2 --- /dev/null +++ b/deepspeed/elasticity/constants.py @@ -0,0 +1,74 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" + +######################################### +# Elasticity +######################################### +''' Elasticity Utility in DeepSpeed can be used to create highly elastic jobs compatible +with a large number of GPUs. For elastic jobs, DeepSpeed will provide a batch size that +can support a large number of GPUs based on the user specified parameters +''' +FORMAT = ''' +Elasticity should be enabled as: +"elasticity": { + "enabled": true, + "max_train_batch_size": 2000, + "micro_batch_sizes": [2,4,6], + "min_gpus": 1, + "max_gpus" : 10000 + "min_time": 20, + "prefer_larger_batch": true, + "ignore_non_elastic_batch_info": false, + "version": 0.1 +} +''' + +ELASTICITY = 'elasticity' + +# Current elasticity version +LATEST_ELASTICITY_VERSION = 0.1 + +ENABLED = 'enabled' +ENABLED_DEFAULT = False + +# Max acceptable train_batch_size +MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size' +MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000 + +# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu +MICRO_BATCHES = 'micro_batch_sizes' +MICRO_BATCHES_DEFAULT = [2, 4, 6] + +# Min/max of GPUs to search over +MIN_GPUS = 'min_gpus' +MIN_GPUS_DEFAULT = 1 +MAX_GPUS = 'max_gpus' +MAX_GPUS_DEFAULT = 10000 + +# Minimum running time (minutes) before the scheduler will scale us +MIN_TIME = "min_time" +MIN_TIME_DEFAULT = "20" + +# When finding a suitable batch size, attempt to find one that is closest +# to the max train batch size given. +PREFER_LARGER_BATCH = 'prefer_larger_batch' +PREFER_LARGER_BATCH_DEFAULT = True + +# In order to reduce confusion, if elastic mode is enabled we +# require (via assert) that no batch info is set outside of the +# elastic config. You can turn off this assert via this config +# but keep in mind that all batch info defined outside the +# elastic mode *will be ignored*. +IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info' +IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False + +# Version of elastic logic to use +VERSION = "version" +VERSION_DEFAULT = LATEST_ELASTICITY_VERSION + +# Minimum deepspeed version to use elasticity +MINIMUM_DEEPSPEED_VERSION = "0.3.8" + +# Environment variable storing elastic config from resource scheduler +DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG" diff --git a/deepspeed/elasticity/elasticity.py b/deepspeed/elasticity/elasticity.py new file mode 100644 index 000000000000..ae91877f5f24 --- /dev/null +++ b/deepspeed/elasticity/elasticity.py @@ -0,0 +1,334 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +import os +import re +import json +import numpy as np + +from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \ + ElasticityIncompatibleWorldSize +from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \ + MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \ + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG +from ..git_version_info import version as __version__ +from ..utils import logger + +# Thirty eight smallest highly composite numbers. The list should +# be enough to support up to 720K batch size. +HCN_LIST = [ + 1, + 2, + 4, + 6, + 12, + 24, + 36, + 48, + 60, + 120, + 180, + 240, + 360, + 720, + 840, + 1260, + 1680, + 2520, + 5040, + 7560, + 10080, + 15120, + 20160, + 25200, + 27720, + 45360, + 50400, + 55440, + 83160, + 110880, + 166320, + 221760, + 277200, + 332640, + 498960, + 554400, + 665280, + 720720 +] + + +def get_candidate_batch_sizes(base_list, max_acceptable_batch_size): + candidate_batch_size = [] + + #brute force is fine here. We are working with very small lists + for base in base_list: + batch_size = base + for hcn in HCN_LIST: + new_batch_size = base * hcn + if new_batch_size > max_acceptable_batch_size: + break + batch_size = new_batch_size + candidate_batch_size.append(batch_size) + return list(set(candidate_batch_size)) + + +def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus): + valid_gpus = [] + for micro_batch in micro_batches: + if batch_size % micro_batch == 0: + + max_gpus = batch_size // micro_batch + if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus: + valid_gpus.append(max_gpus) + + for i in range(1, max_gpus // 2 + 1): + if max_gpus % i == 0: + if i >= min_valid_gpus and i <= max_valid_gpus: + valid_gpus.append(i) + valid_gpus = set(valid_gpus) + valid_gpus = sorted(list(valid_gpus)) + return valid_gpus + + +def get_best_candidates(candidate_batch_sizes, + micro_batches, + min_gpus, + max_gpus, + prefer_larger): + + max_valid_gpus = 0 + valid_gpus = None + final_batch_size = int(min(micro_batches)) + + for batch_size in candidate_batch_sizes: + + current_valid_gpus = get_valid_gpus(batch_size, + micro_batches, + min_gpus, + max_gpus) + + if (len(current_valid_gpus) > max_valid_gpus + or (len(current_valid_gpus) == max_valid_gpus and + ((prefer_larger and batch_size > final_batch_size) or + (not prefer_larger and batch_size < final_batch_size)))): + max_valid_gpus = len(current_valid_gpus) + valid_gpus = current_valid_gpus + final_batch_size = batch_size + + return final_batch_size, valid_gpus + + +def _get_compatible_gpus_v01(micro_batches, + max_acceptable_batch_size, + min_gpus=None, + max_gpus=None, + prefer_larger=True): + '''We use two heuristics to compute the batch size + 1. We use the Lowest Common Multiple of the micro-batches + as the base batch size and scale it by a HCN such that the result is + the largest batch size less than the max_acceptable batch size + 2. We use each of the micro batches as a base and scale it + by a HCN such that the result is the largest batch size less than the + max_acceptable batch size. + + We then use brute force to count the number of compatible GPU count for + each of the aforementioned cases, and return the batch size with the most number of + compatible GPU counts in the min-max GPU range if provided, other wise + we return the batch size with the most number of total compatible GPU counts. + + Returns: + final_batch_size + valid_gpus + ''' + + if min_gpus is None: + min_gpus = int(1) + + if max_gpus is None: + max_gpus = int(max_acceptable_batch_size / min(micro_batches)) + + assert all(mb <= max_acceptable_batch_size for mb in micro_batches ), \ + f"All micro batches must be less than \ + or equal to max_acceptable_batch_size: {max_acceptable_batch_size}" + + lcm = np.lcm.reduce(micro_batches) + + base_list = [] + base_list.extend(micro_batches) + base_list.append(lcm) + + candidate_batch_sizes = get_candidate_batch_sizes(base_list, + max_acceptable_batch_size) + + final_batch_size, valid_gpus = get_best_candidates( + candidate_batch_sizes, + micro_batches, + min_gpus, + max_gpus, + prefer_larger) + + return final_batch_size, valid_gpus + + +def _parse_version(version_str): + '''Parse a version string and extract the major and minor versions (and possibly patch version).''' + matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str) + if matched: + return int(matched.group(1)), int(matched.group(2)), int(matched.group(3)) + else: + matched = re.search('^(\d+)\.(\d+)', version_str) + assert matched != None, "Unable to parse version number, expecting" \ + f"major.minor[.patch] format but received {version_str}" + return int(matched.group(1)), int(matched.group(2)), 0 + + +def _compatible_ds_version_check(target_deepspeed_version: str): + min_major, min_minor, min_patch = _parse_version(MINIMUM_DEEPSPEED_VERSION) + trg_major, trg_minor, trg_patch = _parse_version(target_deepspeed_version) + + err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \ + f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity." + if trg_major < min_major: + raise ElasticityError(err_str) + if trg_minor < min_minor: + raise ElasticityError(err_str) + if trg_patch < min_patch: + raise ElasticityError(err_str) + return True + + +def elasticity_enabled(ds_config: dict): + if ELASTICITY not in ds_config: + return False + return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT) + + +def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict): + """ + Ensure the resource scheduler saw the same elastic config we are using at runtime + """ + if DEEPSPEED_ELASTICITY_CONFIG in os.environ: + scheduler_elastic_config_dict = json.loads( + os.environ[DEEPSPEED_ELASTICITY_CONFIG]) + scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict) + runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict) + err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}" + if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size: + raise ElasticityConfigError( + err_str.format('max_acceptable_batch_size', + scheduler_elastic_config.max_acceptable_batch_size, + 'max_acceptable_batch_size', + runtime_elastic_config.max_acceptable_batch_size)) + if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches: + raise ElasticityConfigError( + err_str.format('micro_batches', + scheduler_elastic_config.micro_batches, + 'micro_batches', + runtime_elastic_config.micro_batches)) + if runtime_elastic_config.version != scheduler_elastic_config.version: + raise ElasticityConfigError( + err_str.format('version', + scheduler_elastic_config.version, + 'version', + runtime_elastic_config.version)) + else: + logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \ + "guarantee resource scheduler will scale this job using compatible GPU counts.") + + +def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0): + """Core deepspeed elasticity API. Given an elastic config (similar to the example below) + DeepSpeed will compute a total train batch size corresponding valid GPU count list that + provides a high level of elasticity. Elasticity in this case means we are safe to scale + the training job up/down across the GPU count list *without* any negative impacts on + training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation + feature which allows us to decompose a global training batch size into: + micro-batch-size * gradient-accumulation-steps * world-size. + + "elasticity": { + "enabled": true, + "max_train_batch_size": 2000, + "micro_batch_sizes": [2,4,6], + "min_gpus": 1, + "max_gpus" : 10000 + "min_time": 20 + "version": 0.1 + } + + Intended to be called both by scheduling infrastructure and deepspeed runtime. + For the same `ds_config` we should return deterministic results. + + Args: + ds_config (dict): DeepSpeed config dictionary/json + target_deepspeed_version (str): When called from scheduling + infrastructure we want to ensure that the target deepspeed version is + compatible with the elasticity version used in the backend. + world_size (int, optional): Intended/current world size, will do some sanity + checks to ensure world size is actually valid with the config. + + Raises: + ElasticityConfigError: Missing required elasticity config or elasticity disabled + ElasticityError: If target deepspeed version is not compatible with current version + + Returns: + final_batch_size (int): total batch size used for training + valid_gpus (list(int)): list of valid GPU counts with this config + micro_batch_size (int, optional): if world_size is provided will return + specific micro batch size + """ + if not isinstance(ds_config, dict): + raise ValueError("Expected ds_config to be a dictionary but received " \ + f"a {type(ds_config)}, containing: {ds_config}") + + if ELASTICITY not in ds_config: + raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \ + " please add it if running an elastic training job.") + + elastic_config_dict = ds_config[ELASTICITY] + if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT): + raise ElasticityConfigError("Elasticity is disabled, please enable it " \ + "('enabled':true) if running an elastic training job.") + + elastic_config = ElasticityConfig(elastic_config_dict) + + if float(elastic_config.version) > LATEST_ELASTICITY_VERSION: + raise ElasticityConfigError("Attempting to run elasticity version " \ + f"{elastic_config.version} but runtime only supports up " \ + f"to {LATEST_ELASTICITY_VERSION}") + + # Ensure target deepspeed version works with intended elasticity version + if not _compatible_ds_version_check(target_deepspeed_version): + raise ElasticityError("Unable to run elasticity on target deepspeed version of" \ + f" {target_deepspeed_version}, currently {__version__}") + + if float(elastic_config.version) == 0.1: + final_batch_size, valid_gpus = _get_compatible_gpus_v01( + micro_batches=elastic_config.micro_batches, + max_acceptable_batch_size=elastic_config.max_acceptable_batch_size, + min_gpus=elastic_config.min_gpus, + max_gpus=elastic_config.max_gpus, + prefer_larger=elastic_config.prefer_larger_batch_size) + # ensure batch size is int dtype + final_batch_size = int(final_batch_size) + else: + raise NotImplementedError( + f"Unable to find elastic logic for version: {elastic_config.version}") + + if world_size > 0: + if world_size not in valid_gpus: + raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \ + f"with the current list of valid GPU counts: {valid_gpus}") + + # Pick largest valid micro batch size + micro_batch_size = None + for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True): + if final_batch_size // world_size % mbsz == 0: + micro_batch_size = mbsz + break + assert micro_batch_size is not None, "Unable to find divisible micro batch size" \ + f" world_size={world_size}, final_batch_size={final_batch_size}, and " \ + f" micro_batches={elastic_config.micro_batches}." + return final_batch_size, valid_gpus, micro_batch_size + + return final_batch_size, valid_gpus diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 4a56aafbc539..9d52dfe6d766 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -6,13 +6,21 @@ import torch import json import copy -from deepspeed.runtime.constants import * -from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE -from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys -from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from deepspeed.runtime.zero.constants import * -from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig -from deepspeed.utils import logger + +from .constants import * +from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE +from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys +from .zero.config import DeepSpeedZeroConfig +from .zero.constants import * +from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig + +from ..git_version_info import version as __version__ +from ..utils import logger + +from ..elasticity import elasticity_enabled, compute_elastic_config, ensure_immutable_elastic_config +from ..elasticity.config import ElasticityConfigError +from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \ + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT TENSOR_CORE_ALIGN_SIZE = 8 @@ -504,6 +512,59 @@ def __init__(self, json_file, mpu=None, param_dict=None): self.global_rank = 0 self.world_size = 1 + # If elastic-mode enabled, update compute + update _param_dict + self.elasticity_enabled = elasticity_enabled(self._param_dict) + if self.elasticity_enabled: + logger.info("DeepSpeed elasticity support enabled") + final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config( + ds_config=self._param_dict, + target_deepspeed_version=__version__, + world_size=self.world_size) + + elastic_dict = self._param_dict[ELASTICITY] + + # Ensure the resource scheduler saw the same elastic config we are using at runtime + ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict) + + ignore_non_elastic_batch_info = elastic_dict.get( + IGNORE_NON_ELASTIC_BATCH_INFO, + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) + + if not ignore_non_elastic_batch_info: + batch_params = [ + TRAIN_BATCH_SIZE, + TRAIN_MICRO_BATCH_SIZE_PER_GPU, + GRADIENT_ACCUMULATION_STEPS + ] + if any(map(lambda t: t in self._param_dict, batch_params)): + raise ElasticityConfigError("One or more batch related parameters were found in your " \ + f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \ + f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \ + "elastic training is enabled, which takes control of these parameters. " \ + "If you want to supress this error (the parameters will be silently ignored) " \ + f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.") + + # micro_bsz * world_size * gas = total_batch_size + # gas = total_batch_size // (micro_bsz * world_size) + gradient_accu_steps = final_batch_size // (micro_batch_size * + self.world_size) + + if TRAIN_BATCH_SIZE in self._param_dict: + logger.warning("[Elasticity] overriding training_batch_size: " \ + f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}") + if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict: + logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \ + f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}") + if GRADIENT_ACCUMULATION_STEPS in self._param_dict: + logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\ + f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}") + + logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}") + + self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size + self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size + self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps + self._initialize_params(self._param_dict) self._configure_train_batch_size() self._do_sanity_check() diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 38fdb647f61d..37f35692369b 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -13,6 +13,10 @@ def get_scalar_param(param_dict, param_name, param_default_value): return param_dict.get(param_name, param_default_value) +def get_list_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + def dict_raise_error_on_duplicate_keys(ordered_pairs): """Reject duplicate keys.""" d = dict((k, v) for k, v in ordered_pairs) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 49e1bedd3cfc..8f86469e1073 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -137,6 +137,10 @@ def __init__(self, self._configure_with_arguments(args, mpu) self._do_sanity_check() + if mpu is not None: + assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ + " with model parallelism." + self._set_distributed_vars() if self.tensorboard_enabled() and self.global_rank == 0: @@ -194,6 +198,22 @@ def __init__(self, self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten + def get_batch_info(self): + """ Get all training batch related settings. + + Returns: + train_batch_size (int): The effective training batch size. This is the amount of data + samples that leads to one step of model update. + train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one + step (without gradient accumulation). + gradient_accumulation_steps (int): Number of training steps to accumulate gradients + before averaging and applying them. + """ + return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps + + def elasticity_enabled(self): + return self._config.elasticity_enabled + def pld_enabled(self): return self._config.pld_enabled @@ -1224,10 +1244,13 @@ def load_checkpoint(self, if tag is None: latest_path = os.path.join(load_dir, 'latest') - assert os.path.isfile(latest_path), f"Unable to find latest file at {latest_path}, if trying to load latest " \ - "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." - with open(latest_path, 'r') as fd: - tag = fd.read().strip() + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + logger.warning(f"Unable to find latest file at {latest_path}, if trying to load latest " \ + "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.") + return None, None load_path, client_states = self._load_checkpoint(load_dir, tag, diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 5c5d896dfc0d..87cc64950006 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -54,6 +54,8 @@ def __init__(self, *super_args, **super_kwargs): # We schedule the all-reduces, so disable it in super().backward() self.enable_backward_allreduce = False + assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ + " with pipeline parallelism." # pipeline step for logging self.log_batch_step_id = -1 diff --git a/op_builder/builder.py b/op_builder/builder.py index 1f350065b4f6..4bdb9e036708 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -33,7 +33,9 @@ def installed_cuda_version(): def get_default_compute_capatabilities(): compute_caps = DEFAULT_COMPUTE_CAPABILITIES - if installed_cuda_version()[0] >= 11: + import torch.utils.cpp_extension + if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version( + )[0] >= 11: compute_caps += ";8.0;8.6" return compute_caps diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5845cdff4452..9192befdd35c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,3 +3,4 @@ torchvision>=0.4.0 tqdm tensorboardX==1.8 ninja +numpy diff --git a/setup.py b/setup.py index bf2ff9813537..19df040dcc88 100755 --- a/setup.py +++ b/setup.py @@ -184,7 +184,8 @@ def op_enabled(op_name): 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', - 'bin/ds_report' + 'bin/ds_report', + 'bin/ds_elastic' ], classifiers=[ 'Programming Language :: Python :: 3.6', diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 1fbcacfa2aa4..1cd817ebc561 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -757,7 +757,7 @@ def _helper(args, model, hidden_dim): model, _, _,_ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) - with pytest.raises(AssertionError): - model.load_checkpoint(tmpdir) + # should be no-op, since latest doesn't exist + model.load_checkpoint(tmpdir) _helper(args=args, model=model, hidden_dim=hidden_dim) diff --git a/tests/unit/test_elastic.py b/tests/unit/test_elastic.py new file mode 100644 index 000000000000..339112b1bc93 --- /dev/null +++ b/tests/unit/test_elastic.py @@ -0,0 +1,241 @@ +import pytest +import deepspeed +from common import distributed_test +from deepspeed.git_version_info import version as ds_version +from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict + +base_ds_config = { + "elasticity": { + "enabled": True, + "max_train_batch_size": 10000, + "micro_batch_sizes": [8, + 12, + 16, + 17], + "min_gpus": 32, + "max_gpus": 1500, + "min_time": 20, + "version": 0.1 + } +} + + +def test_basic_10k(): + ds_config = base_ds_config.copy() + final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config( + ds_config=ds_config, + target_deepspeed_version=ds_version) + + for gpu_num in valid_gpus: + assert final_batch_size % gpu_num == 0, f"Batch {final_batch_size} is not divisible by GPU count {gpu_num}" + batch_per_gpu = final_batch_size // gpu_num + found_valid_mbsize = False + + for mb in ds_config['elasticity']['micro_batch_sizes']: + if batch_per_gpu % mb == 0: + found_valid_mb = True + break + assert found_valid_mb, "No valid mb found" + + assert len(valid_gpus) == 23 + assert final_batch_size == 9792 + + +def test_old_version(): + ds_config = base_ds_config.copy() + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config( + ds_config=ds_config, + target_deepspeed_version="0.2") + + +def test_disabled(): + ds_config = base_ds_config.copy() + ds_config['elasticity']['enabled'] = False + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config( + ds_config=ds_config, + target_deepspeed_version=ds_version) + + +def test_valid_world_size(): + ds_config = base_ds_config.copy() + final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config( + ds_config=ds_config, + target_deepspeed_version=ds_version, + world_size=64) + assert mbsize == 17 + + +def test_invalid_world_size(): + ds_config = base_ds_config.copy() + with pytest.raises(deepspeed.elasticity.config.ElasticityIncompatibleWorldSize): + final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config( + ds_config=ds_config, + target_deepspeed_version=ds_version, + world_size=128) + + +def test_future_elastic_version(): + ds_config = base_ds_config.copy() + ds_config['elasticity']['version'] = '0.2' + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, + target_deepspeed_version=ds_version) + + +def test_missing_max_batch(): + ds_config = base_ds_config.copy() + del ds_config['elasticity']['max_train_batch_size'] + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, + target_deepspeed_version=ds_version) + + +def test_missing_micro_batch(): + ds_config = base_ds_config.copy() + del ds_config['elasticity']['micro_batch_sizes'] + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, + target_deepspeed_version=ds_version) + + +def test_empty_config(): + ds_config = {"elasticity": {"enabled": True}} + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, + target_deepspeed_version=ds_version) + + +def test_proper_mbsz(): + ds_config = base_ds_config.copy() + ds_config["elasticity"]["max_train_batch_size"] = 32 + ds_config["elasticity"]["micro_batch_sizes"] = [1, 2, 3, 7] + ds_config["elasticity"]["min_gpus"] = 1 + final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config( + ds_config=ds_config, + target_deepspeed_version=ds_version, + world_size=7) + assert mbsize == 3 + + +def test_non_elastic_batch_params(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Lamb", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "elasticity": { + "enabled": True, + "max_train_batch_size": 4, + "micro_batch_sizes": [1, + 2, + 3, + 4], + "min_gpus": 1, + "max_gpus": 4, + "min_time": 20, + "version": 0.1 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1, 2]) + def _test_elastic(args, model, hidden_dim): + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + _test_elastic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_non_elastic_batch_params_w_override(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Lamb", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "elasticity": { + "enabled": True, + "max_train_batch_size": 4, + "micro_batch_sizes": [1, + 2, + 3, + 4], + "min_gpus": 1, + "max_gpus": 4, + "min_time": 20, + "version": 0.1, + "ignore_non_elastic_batch_info": True + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1, 2]) + def _test_elastic(args, model, hidden_dim): + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + _test_elastic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_elastic_config_changed(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Lamb", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "elasticity": { + "enabled": True, + "max_train_batch_size": 4, + "micro_batch_sizes": [1, + 2, + 3, + 4], + "min_gpus": 1, + "max_gpus": 4, + "min_time": 20, + "version": 0.1, + "ignore_non_elastic_batch_info": True + } + } + import json, os + scheduler_elastic_config = config_dict.copy() + scheduler_elastic_config["elasticity"]["max_train_batch_size"] = 27 + os.environ['DEEPSPEED_ELASTICITY_CONFIG'] = json.dumps(scheduler_elastic_config) + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1, 2]) + def _test_elastic(args, model, hidden_dim): + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + _test_elastic(args=args, model=model, hidden_dim=hidden_dim) diff --git a/version.txt b/version.txt index 667843220966..940ac09aa677 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.8 +0.3.9 From 24e07399fc3ee95417d7a582c3e8952e1ac7d567 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 4 Jan 2021 12:11:13 -0800 Subject: [PATCH 12/17] update SA comp check to fix torch-cpu issue (#631) --- op_builder/sparse_attn.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py index 4c716f859970..c3fa5624b25e 100644 --- a/op_builder/sparse_attn.py +++ b/op_builder/sparse_attn.py @@ -25,6 +25,18 @@ def is_compatible(self): command_status = list(map(self.command_exists, required_commands)) deps_compatible = all(command_status) + # torch-cpu will not have a cuda version + if torch.version.cuda is None: + cuda_compatible = False + self.warning(f"{self.NAME} cuda is not available from torch") + else: + major, minor = torch.version.cuda.split('.')[:2] + cuda_compatible = int(major) == 10 and int(minor) >= 1 + if not cuda_compatible: + self.warning( + f"{self.NAME} requires CUDA version 10.1+, does not currently support >=11 or <10.1" + ) + TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5 @@ -33,4 +45,5 @@ def is_compatible(self): f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}' ) - return super().is_compatible() and deps_compatible and torch_compatible + return super().is_compatible( + ) and deps_compatible and torch_compatible and cuda_compatible From e6ac7311363da55cf22abfe413679dd34b5ad756 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Mon, 4 Jan 2021 15:55:41 -0800 Subject: [PATCH 13/17] Support initialization with dict configuration (#632) --- deepspeed/runtime/engine.py | 6 +++--- tests/unit/simple_model.py | 11 ++++++++--- tests/unit/test_fp16.py | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 8f86469e1073..a87a56cb5b9b 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -435,9 +435,9 @@ def _set_distributed_vars(self): # Configure based on command line arguments def _configure_with_arguments(self, args, mpu): self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0 - self._config = DeepSpeedConfig(args.deepspeed_config, - mpu, - param_dict=self.config_params) + config_file = args.deepspeed_config if hasattr(args, + 'deepspeed_config') else None + self._config = DeepSpeedConfig(config_file, mpu, param_dict=self.config_params) # Validate command line arguments def _do_args_sanity_check(self, args): diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py index b0c268341224..17215cd323bb 100755 --- a/tests/unit/simple_model.py +++ b/tests/unit/simple_model.py @@ -161,12 +161,10 @@ def create_config_from_dict(tmpdir, config_dict): return config_path -def args_from_dict(tmpdir, config_dict): - config_path = create_config_from_dict(tmpdir, config_dict) +def create_deepspeed_args(): parser = argparse.ArgumentParser() args = parser.parse_args(args='') args.deepspeed = True - args.deepspeed_config = config_path if torch.distributed.is_initialized(): # We assume up to one full node executing unit tests assert torch.distributed.get_world_size() <= torch.cuda.device_count() @@ -174,3 +172,10 @@ def args_from_dict(tmpdir, config_dict): else: args.local_rank = 0 return args + + +def args_from_dict(tmpdir, config_dict): + args = create_deepspeed_args() + config_path = create_config_from_dict(tmpdir, config_dict) + args.deepspeed_config = config_path + return args diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 30d53a00251f..eb1ecc86425d 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -6,7 +6,7 @@ import os from deepspeed.ops.adam import FusedAdam from common import distributed_test -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args try: from apex import amp @@ -194,6 +194,41 @@ def _test_adamw_fp16_basic(args, model, hidden_dim): _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) +def test_dict_config_adamw_fp16_basic(): + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "fp16": { + "enabled": True + } + } + args = create_deepspeed_args() + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1]) + def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict): + optimizer = torch.optim.AdamW(params=model.parameters()) + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + optimizer=optimizer, + config_params=config_dict) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_adamw_fp16_basic(args=args, + model=model, + hidden_dim=hidden_dim, + config_dict=config_dict) + + def test_adamw_fp16_empty_grad(tmpdir): config_dict = { "train_batch_size": 1, From a9a83a6fcfcf654d75017453fbb3a476000180ce Mon Sep 17 00:00:00 2001 From: gcooper-isi <42359489+gcooper-isi@users.noreply.github.com> Date: Tue, 5 Jan 2021 13:14:29 -0500 Subject: [PATCH 14/17] Allow DeepSpeed models to be initialized with optimizer=None (#469) Allow DeepSpeed models to be initialized with optimizer=None Co-authored-by: Shaden Smith --- deepspeed/runtime/engine.py | 9 ++++----- tests/unit/test_config.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a87a56cb5b9b..99db78ec6dc5 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -466,10 +466,9 @@ def _is_supported_optimizer(self, optimizer_name): # Validate configuration based on command line arguments def _do_sanity_check(self): if not self.client_optimizer: - assert self._is_supported_optimizer(self.optimizer_name()), \ - '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name()) - assert self.client_model_parameters, \ - 'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name()) + if self.optimizer_name() is not None: + assert self._is_supported_optimizer(self.optimizer_name()), \ + '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name()) if self.optimizer_name() == LAMB_OPTIMIZER: assert self.dynamic_loss_scale(), \ @@ -1289,7 +1288,7 @@ def _load_checkpoint(self, self.load_module_state_dict(state_dict=checkpoint['module'], strict=load_module_strict) - if not self.zero_optimization(): + if self.optimizer is not None and not self.zero_optimization(): if self.fp16_enabled(): self.optimizer.load_state_dict( checkpoint['optimizer'], diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index e5fe75b281e0..4cabefe71a33 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -195,3 +195,34 @@ def _test_dist_init_true(args, model, hidden_dim): model.step() _test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim) + + +def test_init_no_optimizer(tmpdir): + + config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}} + config_path = create_config_from_dict(tmpdir, config_dict) + + @distributed_test(world_size=1) + def _helper(): + parser = argparse.ArgumentParser() + args = parser.parse_args(args='') + args.deepscale_config = config_path + args.local_rank = 0 + + hidden_dim = 10 + + model = SimpleModel(hidden_dim=hidden_dim) + + model, _, _, _ = deepspeed.initialize(args=args, model=model) + data_loader = random_dataloader(model=model, + total_samples=5, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + with pytest.raises(AssertionError): + model.backward(loss) + with pytest.raises(AssertionError): + model.step() + + _helper() From d38ad6a17164b9bf07477ceb17ca5c7f09657720 Mon Sep 17 00:00:00 2001 From: Ammar Ahmad Awan Date: Tue, 5 Jan 2021 10:31:49 -0800 Subject: [PATCH 15/17] change dist to torch.distributed to fix bug in assert. (#638) --- deepspeed/utils/distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/utils/distributed.py b/deepspeed/utils/distributed.py index e70f00b440bb..c9722af21c24 100644 --- a/deepspeed/utils/distributed.py +++ b/deepspeed/utils/distributed.py @@ -77,9 +77,9 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True) os.environ['MASTER_PORT'])) if torch.distributed.is_initialized(): - assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) - assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( - world_size, dist.get_world_size()) + assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) + assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( + world_size, torch.distributed.get_world_size()) def in_aml(): From 46d2e2872b64ebccb8bf4eb5c8a3a55f9adaaa6c Mon Sep 17 00:00:00 2001 From: brett koonce Date: Tue, 5 Jan 2021 12:32:44 -0600 Subject: [PATCH 16/17] docs: minor spelling tweaks (#623) Co-authored-by: Jeff Rasley --- docs/_pages/features.md | 8 ++++---- docs/_posts/2020-09-08-sparse-attention-news.md | 2 +- docs/_tutorials/onebit-adam.md | 4 ++-- docs/_tutorials/pipeline.md | 2 +- docs/_tutorials/progressive_layer_dropping.md | 2 +- docs/_tutorials/zero.md | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/_pages/features.md b/docs/_pages/features.md index 2074bb3e3b0f..3ad1c8e91984 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -79,7 +79,7 @@ DeepSpeed. ### Optimizer State and Gradient Partitioning Optimizer State and Gradient Partitioning in ZeRO reduces the memory consumption of the -model states (optimizer states, gradients and parmaeters) by 8x compared to standard +model states (optimizer states, gradients and parameters) by 8x compared to standard data parallelism by partitioning these states across data parallel process instead of replicating them. @@ -150,8 +150,8 @@ Please see the [core API doc](https://deepspeed.readthedocs.io/) for more detail ### Activation Checkpointing API -DeepSpeed's Activation Checkpoinitng API supports activation checkpoint partitioning, -cpu checkpoiniting, and contiguous memory optimizations, while also allowing layerwise +DeepSpeed's Activation Checkpointing API supports activation checkpoint partitioning, +cpu checkpointing, and contiguous memory optimizations, while also allowing layerwise profiling. Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details. @@ -190,7 +190,7 @@ NVIDIA, or any training optimizer that extends torch's `torch.optim.Optimizer` c We introduce an efficient implementation of Adam optimizer on CPU that improves the parameter-update performance by nearly an order of magnitude. We use the AVX SIMD instructions on Intel-x86 architecture for the CPU-Adam implementation. We support both AVX-512 and AVX-2 instruction sets. DeepSpeed uses -AVX-2 by defualt which can be switched to AVX-512 by setting the build flag, `DS_BUILD_AVX512` to 1 when +AVX-2 by default which can be switched to AVX-512 by setting the build flag, `DS_BUILD_AVX512` to 1 when installing DeepSpeed. Using AVX-512, we observe 5.1x to 6.5x speedups considering the model-size between 1 to 10 billion parameters with respect to torch-adam. diff --git a/docs/_posts/2020-09-08-sparse-attention-news.md b/docs/_posts/2020-09-08-sparse-attention-news.md index ca133df61123..6f235818c33f 100644 --- a/docs/_posts/2020-09-08-sparse-attention-news.md +++ b/docs/_posts/2020-09-08-sparse-attention-news.md @@ -12,4 +12,4 @@ DeepSpeed offers sparse attention kernels, an instrumental technology to support * Brief overview, see our [press release]({{ site.press_release_v3 }}). * Detailed technology deep dive, see our [blog post](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html). * Tutorial on how to use sparse attention, see our [Sparse attention tutorial](https://www.deepspeed.ai/tutorials/sparse-attention/). -* The source code for our sparse attention kernels can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed) and BERT pre-training code useing sparse attention can be found in the [DeepSpeedExamples repo](https://github.com/microsoft/deepspeedexamples). +* The source code for our sparse attention kernels can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed) and BERT pre-training code using sparse attention can be found in the [DeepSpeedExamples repo](https://github.com/microsoft/deepspeedexamples). diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index 4039589b2ed3..8871a5dd0e28 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -120,7 +120,7 @@ Alternatively, we show how the standard `mpirun` launcher can be used for launch mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] bash run_squad_mpi_onebitadam.sh ``` -For example, in order to use 32 GPUs (4GPUs/node, 8 nodes in total), with the support of InfiniBand, you can use the `mpirun` launcher packaged with the MVAPICH2 library. Please run the folowing command: +For example, in order to use 32 GPUs (4GPUs/node, 8 nodes in total), with the support of InfiniBand, you can use the `mpirun` launcher packaged with the MVAPICH2 library. Please run the following command: ```shell mpirun -np 32 -ppn 4 -hostfile hosts -env MV2_USE_CUDA=1 -env MV2_SUPPORT_DL=1 -env MV2_ENABLE_AFFINITY=0 -env MV2_SMP_USE_CMA=0 bash run_squad_mpi_onebitadam.sh @@ -166,7 +166,7 @@ We fixed the learning rate to 3e-5. The table below shows the F1 and the EM scor ***Training Speed and Scalability:*** -1-bit Adam enables up to 2.7x overall speedup in training speed for SQuAD fine-tuning. This is made possible by up to 6.2x faster througput during the compressed stage of the algorithm as shown in Figure 1. +1-bit Adam enables up to 2.7x overall speedup in training speed for SQuAD fine-tuning. This is made possible by up to 6.2x faster throughput during the compressed stage of the algorithm as shown in Figure 1. ![SQuAD Finetuning](/assets/images/squad-scaling.png){: .align-center} diff --git a/docs/_tutorials/pipeline.md b/docs/_tutorials/pipeline.md index 64d7528ee6fb..e7730ebe2661 100644 --- a/docs/_tutorials/pipeline.md +++ b/docs/_tutorials/pipeline.md @@ -75,7 +75,7 @@ net = PipelineModule(layers=net, num_stages=2) ``` `PipelineModule` uses its `layers` argument as the sequence of layers that comprise the model. After initialization, `net` is divided into two pipeline -stages and its layers moved to the correpsonding GPUs. If more than two GPUs +stages and its layers moved to the corresponding GPUs. If more than two GPUs are present, DeepSpeed will also use hybrid data parallelism. **Note:** The total number of GPUs must be divisible by the number of pipeline diff --git a/docs/_tutorials/progressive_layer_dropping.md b/docs/_tutorials/progressive_layer_dropping.md index 4958717f8d09..8a447e97c945 100755 --- a/docs/_tutorials/progressive_layer_dropping.md +++ b/docs/_tutorials/progressive_layer_dropping.md @@ -95,7 +95,7 @@ Note that the above configuration assumes training on 64 X 32GB V100 GPUs. Each Table 1. Pre-training hyperparameters -**Note:** DeepSpeed now supports PreLayerNorm as the default way for training BERT, because of its ability to avoid vanishing gradient, stablize optimization, and performance gains, as described in our fastest BERT training [blog post](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). We therefore support the switchable Transformer block directly on the the BERT with PreLayerNorm. The implementation can be found at "example\bing_bert\nvidia\modelingpreln_layerdrop.py". +**Note:** DeepSpeed now supports PreLayerNorm as the default way for training BERT, because of its ability to avoid vanishing gradient, stabilize optimization, and performance gains, as described in our fastest BERT training [blog post](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). We therefore support the switchable Transformer block directly on the the BERT with PreLayerNorm. The implementation can be found at "example\bing_bert\nvidia\modelingpreln_layerdrop.py". ## Fine-tuning with DeepSpeed on GLUE Tasks diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md index 356f2369e54a..45d663a52563 100644 --- a/docs/_tutorials/zero.md +++ b/docs/_tutorials/zero.md @@ -79,7 +79,7 @@ Next, we need to update the DeepSpeed json configuration, as shown below, to ena } ``` -In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmenation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now launch the training run. +In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmentation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now launch the training run. Here is a screenshot of the training log: From 5ab12795958dff0b56d3faf7421abe9cad7bcba1 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 5 Jan 2021 13:55:21 -0800 Subject: [PATCH 17/17] Fix docstring format (#640) --- deepspeed/runtime/engine.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 99db78ec6dc5..27ce1c7439fc 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1236,9 +1236,12 @@ def load_checkpoint(self, load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match. load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. - Return: - load_path: Path of the loaded checkpoint. None if loading the checkpoint failed - client_state: State dictionary used for loading required training states in the client code. + Returns: + A tuple of ``load_path`` and ``client_state``. + + *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. + + *``client_state``: State dictionary used for loading required training states in the client code. """ if tag is None: