diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index a5b6cd5c28e5..de325cb20826 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -3,9 +3,9 @@ name: nv-a6000 on: pull_request: paths: - - "deepspeed/inference/v2/**" - - "tests/unit/inference/v2/**" - - ".github/workflows/nv-a6000.yml" + - 'deepspeed/inference/v2/**' + - 'tests/unit/inference/v2/**' + - '.github/workflows/nv-a6000.yml' workflow_dispatch: concurrency: diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index 0f6491e08336..d8a03ff34f78 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index f20b4496b6df..156fa0228d0b 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-lightning-v100.yml b/.github/workflows/nv-lightning-v100.yml index d25d40aef967..ffcecb1e0d36 100644 --- a/.github/workflows/nv-lightning-v100.yml +++ b/.github/workflows/nv-lightning-v100.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-megatron.yml b/.github/workflows/nv-megatron.yml index 3a3b70dcd17d..0c3fc4c5ef5e 100644 --- a/.github/workflows/nv-megatron.yml +++ b/.github/workflows/nv-megatron.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index 839312190d22..505e73fee156 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -8,7 +8,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: @@ -19,7 +19,7 @@ concurrency: cancel-in-progress: true jobs: - build-ops: + unit-tests: runs-on: ubuntu-20.04 container: image: deepspeed/gh-builder:ubuntu1804-py38-torch1131-cu116 diff --git a/.github/workflows/nv-torch-latest-cpu.yml b/.github/workflows/nv-torch-latest-cpu.yml index 9ca1529d9018..375b984134cb 100644 --- a/.github/workflows/nv-torch-latest-cpu.yml +++ b/.github/workflows/nv-torch-latest-cpu.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index 8813a4bb2c4f..e4e61acd8143 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git 
a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml index 7753133f2886..4fbc42abec5f 100644 --- a/.github/workflows/nv-transformers-v100.yml +++ b/.github/workflows/nv-transformers-v100.yml @@ -6,7 +6,7 @@ on: - 'docs/**' - 'blogs/**' - 'deepspeed/inference/v2/**' - - "tests/unit/inference/v2/**" + - 'tests/unit/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index fdbbd33c07a2..843e55ac3d20 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -63,7 +63,7 @@ def random(self): return torch.random def set_rng_state(self, new_state, device_index=None): - if device_index == None: + if device_index is None: return torch.set_rng_state(new_state) return torch.set_rng_state(new_state, device_index) @@ -253,7 +253,7 @@ def on_accelerator(self, tensor): # create an instance of op builder and return, name specified by class_name def create_op_builder(self, op_name): builder_class = self.get_op_builder(op_name) - if builder_class != None: + if builder_class is not None: return builder_class() return None diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 2786b425ca7f..521cba6a5fdf 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -44,7 +44,7 @@ def is_synchronized_device(self): # Device APIs def device_name(self, device_index=None): - if device_index == None: + if device_index is None: return 'cuda' return 'cuda:{}'.format(device_index) @@ -280,7 +280,7 @@ def op_builder_dir(self): class_dict = None def _lazy_init_class_dict(self): - if self.class_dict != None: + if self.class_dict is not None: return else: self.class_dict = {} diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index 77595f6b636c..300b224a7af8 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -26,7 +26,7 @@ def is_synchronized_device(self): # Device APIs def device_name(self, device_index=None): - if device_index == None: + if device_index is None: return "mps" return "mps:{}".format(device_index) @@ -221,7 +221,7 @@ def op_builder_dir(self): # create an instance of op builder, specified by class_name def create_op_builder(self, op_name): builder_class = self.get_op_builder(op_name) - if builder_class != None: + if builder_class is not None: return builder_class() return None diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 7ebbd320bb15..951bc26c197c 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -30,7 +30,7 @@ def is_synchronized_device(self): # Device APIs def device_name(self, device_index=None): - if device_index == None: + if device_index is None: return 'npu' return 'npu:{}'.format(device_index) diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 49133489b051..0389252e8894 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -45,7 +45,7 @@ def _validate_accelerator(accel_obj): def is_current_accelerator_supported(): - return get_accelerator() in SUPPORTED_ACCELERATOR_LIST + return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST def get_accelerator(): diff --git a/blogs/deepspeed-fastgen/README.md b/blogs/deepspeed-fastgen/README.md index c309a9def53f..4081c780e09a 100644 --- a/blogs/deepspeed-fastgen/README.md +++ b/blogs/deepspeed-fastgen/README.md @@ -228,6 +228,7 @@ We currently support the 
following model architectures in this alpha release of * [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2) * [Mistral](https://huggingface.co/models?other=mistral) * [OPT](https://huggingface.co/models?other=opt) +* [Falcon](https://huggingface.co/models?other=falcon) All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer. diff --git a/deepspeed/inference/quantization/layers.py b/deepspeed/inference/quantization/layers.py index c90354aca90f..e9a7e5629f1b 100644 --- a/deepspeed/inference/quantization/layers.py +++ b/deepspeed/inference/quantization/layers.py @@ -86,7 +86,7 @@ def __init__(self, config: Dict, pre_quant_layer: nn.Embedding) -> None: device=pre_quant_layer.weight.device, dtype=pre_quant_layer.weight.dtype) - assert pre_quant_layer.max_norm == None, 'Not supported' + assert pre_quant_layer.max_norm is None, 'Not supported' assert pre_quant_layer.norm_type == 2, 'Not supported' assert pre_quant_layer.scale_grad_by_freq == False, 'Not supported' assert pre_quant_layer.sparse == False, 'Not supported' diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index ecca9f3c1b34..9558125ff934 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -17,6 +17,7 @@ OPTPolicy, Llama2Policy, MistralPolicy, + FalconPolicy, ) from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata @@ -104,6 +105,8 @@ def build_hf_engine(path: str, assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \ f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}" policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "falcon": + policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) else: raise ValueError(f"Unsupported model type {model_config.model_type}") diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu index 63ea5bc88bab..980334f02b0b 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu @@ -227,6 +227,16 @@ void launch_kv_rotary_kernel(T* kv_cache, DISPATCH_KV_ROTARY_IMPL(5, 128) DISPATCH_KV_ROTARY_IMPL(8, 64) DISPATCH_KV_ROTARY_IMPL(8, 128) + DISPATCH_KV_ROTARY_IMPL(16, 64) + DISPATCH_KV_ROTARY_IMPL(16, 128) + DISPATCH_KV_ROTARY_IMPL(29, 64) + DISPATCH_KV_ROTARY_IMPL(29, 128) + DISPATCH_KV_ROTARY_IMPL(35, 64) + DISPATCH_KV_ROTARY_IMPL(35, 128) + DISPATCH_KV_ROTARY_IMPL(36, 64) + DISPATCH_KV_ROTARY_IMPL(36, 128) + DISPATCH_KV_ROTARY_IMPL(71, 64) + DISPATCH_KV_ROTARY_IMPL(71, 128) } #define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \ diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py index 630d58d90a23..50d9aca061f3 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py +++ 
b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase): supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] supported_head_sizes = [64, 128] - supported_q_ratios = [1, 2, 4, 5, 8] + supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71] def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: """ diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index dae406271245..481be2e5940e 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -12,3 +12,4 @@ from .llama_v2 import * from .opt import * from .mistral import * +from .falcon import * diff --git a/deepspeed/inference/v2/model_implementations/falcon/__init__.py b/deepspeed/inference/v2/model_implementations/falcon/__init__.py new file mode 100644 index 000000000000..ff66879b44be --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .falcon_policy import FalconPolicy diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py b/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py new file mode 100644 index 000000000000..f3cbe6609cdd --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ...model_implementations.common_parameters import * +from ...model_implementations.layer_container_base import LayerContainer +''' + # HF Falcon 7b model looks like this: + +FalconForCausalLM( + (transformer): FalconModel( + (word_embeddings): Embedding(65024, 4544) + (h): ModuleList( + (0-31): 32 x FalconDecoderLayer( + (self_attention): FalconAttention( + (maybe_rotary): FalconRotaryEmbedding() + (query_key_value): FalconLinear(in_features=4544, out_features=4672, bias=False) + (dense): FalconLinear(in_features=4544, out_features=4544, bias=False) + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (mlp): FalconMLP( + (dense_h_to_4h): FalconLinear(in_features=4544, out_features=18176, bias=False) + (act): GELU(approximate='none') + (dense_4h_to_h): FalconLinear(in_features=18176, out_features=4544, bias=False) + ) + (input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True) + ) + ) + (ln_f): LayerNorm((4544,), eps=1e-05, elementwise_affine=True) + ) + (lm_head): Linear(in_features=4544, out_features=65024, bias=False) +) +''' + + +class FalconTransformerContainer(LayerContainer): + """ + Transformer layer container for the Falcon model. 
+ """ + qkv_w: FusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_2_w: MLP2Parameter + ln_attn_gamma: NormParameter + ln_attn_beta: NormParameter + + PARAM_MAPPING = { + "self_attention.query_key_value.weight": "qkv_w.params", + "self_attention.dense.weight": "attn_out_w.params", + "mlp.dense_h_to_4h.weight": "mlp_1_w.params", + "mlp.dense_4h_to_h.weight": "mlp_2_w.params", + "input_layernorm.weight": "ln_attn_gamma.params", + "input_layernorm.bias": "ln_attn_beta.params", + } + + +class FalconNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Falcon model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm_gamma: NormParameter + final_norm_beta: NormParameter + + PARAM_MAPPING = { + "transformer.word_embeddings.weight": "word_emb.params", + "transformer.ln_f.weight": "final_norm_gamma.params", + "transformer.ln_f.bias": "final_norm_beta.params", + "lm_head.weight": "word_unembed.params", + } + + +''' + # HF Falcon 40b model looks like this: + + FalconForCausalLM( + (transformer): FalconModel( + (word_embeddings): Embedding(65024, 8192) + (h): ModuleList( + (0-59): 60 x FalconDecoderLayer( + (self_attention): FalconAttention( + (maybe_rotary): FalconRotaryEmbedding() + (query_key_value): FalconLinear(in_features=8192, out_features=9216, bias=False) + (dense): FalconLinear(in_features=8192, out_features=8192, bias=False) + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (mlp): FalconMLP( + (dense_h_to_4h): FalconLinear(in_features=8192, out_features=32768, bias=False) + (act): GELU(approximate='none') + (dense_4h_to_h): FalconLinear(in_features=32768, out_features=8192, bias=False) + ) + (ln_attn): LayerNorm((8192,), eps=1e-05, elementwise_affine=True) + (ln_mlp): LayerNorm((8192,), eps=1e-05, elementwise_affine=True) + ) + ) + (ln_f): LayerNorm((8192,), eps=1e-05, elementwise_affine=True) + ) + (lm_head): Linear(in_features=8192, out_features=65024, bias=False) +) +''' + + +class FalconNewArchTransformerContainer(LayerContainer): + """ + Transformer layer container for the Falcon model. + """ + qkv_w: GQAMegatronQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_2_w: MLP2Parameter + ln_attn_gamma: NormParameter + ln_attn_beta: NormParameter + ln_mlp_gamma: NormParameter + ln_mlp_beta: NormParameter + + PARAM_MAPPING = { + "self_attention.query_key_value.weight": "qkv_w.params", + "self_attention.dense.weight": "attn_out_w.params", + "mlp.dense_h_to_4h.weight": "mlp_1_w.params", + "mlp.dense_4h_to_h.weight": "mlp_2_w.params", + "ln_attn.weight": "ln_attn_gamma.params", + "ln_attn.bias": "ln_attn_beta.params", + "ln_mlp.weight": "ln_mlp_gamma.params", + "ln_mlp.bias": "ln_mlp_beta.params", + } diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py b/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py new file mode 100644 index 000000000000..a00f754744a4 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer + + +class FalconInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Falcon models. + """ + + _non_transformer: Optional[FalconNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[FalconTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return 4 * self._config.hidden_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_kv_heads if (self._config.new_decoder_architecture + or not self._config.multi_query) else 1 + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.GELU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer.
This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + assert self.config.parallel_attn, "Only parallel attention implementation is supported" + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + attn_ln_out = hidden_states + attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=None) + attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info) + attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=None) + + if self.config.new_decoder_architecture: + residual, mlp_ln_out = self.norm(residual, + None, + gamma=cur_params.ln_mlp_gamma, + beta=cur_params.ln_mlp_beta) + else: + mlp_ln_out = hidden_states + + mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=None) + mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=None) + + mlp_output.add_(attention_output) + + if self.tp_size > 1: + dist.all_reduce(mlp_output, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, mlp_output = self.norm(residual, + mlp_output, + next_params.ln_attn_gamma, + beta=next_params.ln_attn_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(mlp_output) + + return residual, mlp_output + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].ln_attn_gamma, + beta=self._transformer[0].ln_attn_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py b/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py new file mode 100644 index 000000000000..5672d45a8d13 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy +from ...model_implementations.falcon.falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer +from ...model_implementations.falcon.falcon_containers import FalconNewArchTransformerContainer +from ...model_implementations.falcon.falcon_model import FalconInferenceModel + + +class FalconPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> FalconInferenceModel: + return FalconInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + trans_container_cls = FalconNewArchTransformerContainer if self._model_config.new_decoder_architecture else FalconTransformerContainer + transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['transformer.h'], transformer_containers) + + map.set_non_transformer_params(FalconNonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py b/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py index 6782bcae81c8..90cdd39d1be7 100644 --- a/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py +++ b/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py @@ -32,7 +32,7 @@ def supports_config(config: DSEmbeddingsConfig) -> bool: if config.use_token_type: return False - if config.output_normalization != None: + if config.output_normalization is not None: return False try: diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index c235cc766209..6f545d4cb13b 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -385,7 +385,8 @@ def update_mp_params(self, child): return for param in [ "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", - "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads" + "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", + "d_model" ]: if hasattr(child, param): param_val = getattr(child, param) @@ -450,7 +451,7 @@ def get_model_num_kv_heads(self, config): for name in kv_head_names: if hasattr(config, name): num_kv_heads = getattr(config, name) - if num_kv_heads != None: + if num_kv_heads is not None: break return num_kv_heads diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index 0fca37c6b5e8..d61e78ab8d0e 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -28,7 +28,7 @@ def require_tp_fused_qkvw(name, mp_size): def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index): - if src == None: + if src is None: return fused_type_dict = { 'CodeGenBlock': 'codegentype', diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index d37d8c163f07..5b7d2209d89e 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ 
-600,7 +600,7 @@ def skip_level_0_prefix(model, state_dict): if key is None: key = re.match(r"(.*?)Model", model) # if keys start with 'model.', don't skip level 0 prefix - if state_dict != None: + if state_dict is not None: for item in state_dict.keys(): if re.match("^model[.]", item): return False diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 8e2fa78d883f..302b3c33c953 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -20,8 +20,8 @@ def get_num_kv_heads(): def get_shard_size(total_size, mp_size, rank=None): global num_kv_heads # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division - if num_kv_heads != None: - if (rank == None): + if num_kv_heads is not None: + if rank is None: rank = dist.get_rank() my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) return total_size * my_slices // num_kv_heads diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index a02ddbe86403..888505279290 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -37,7 +37,8 @@ def __init__(self, norm_type=2, allgather_bucket_size=5000000000, dp_process_group=None, - timers=None): + timers=None, + grad_acc_dtype=None): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -45,6 +46,10 @@ def __init__(self, self.param_names = param_names self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim) + assert grad_acc_dtype in [torch.float32, torch.bfloat16 + ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}" + self.grad_acc_dtype = grad_acc_dtype + self.clip_grad = clip_grad self.norm_type = norm_type self.mpu = mpu @@ -119,7 +124,8 @@ def _setup_for_real_optimizer(self): num_elem_list = [t.numel() for t in self.bf16_groups[i]] # create fp32 gradients - self.fp32_groups_gradients_flat.append(torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32)) + self.fp32_groups_gradients_flat.append( + torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)) # track individual fp32 gradients for entire model fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i], @@ -204,10 +210,16 @@ def initialize_optimizer_states(self): """ for param_partition, grad_partition in zip(self.fp32_groups_flat_partition, self.fp32_groups_gradient_flat_partition): - param_partition.grad = grad_partition + # In case of grad acc dtype different than FP32, need to cast to high precision. + param_partition.grad = grad_partition.to( + param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition self.optimizer.step() + if self.grad_acc_dtype is not torch.float32: + for param_partition in self.fp32_groups_flat_partition: + param_partition.grad = None + self.clear_hp_grads() def _split_flat_tensor(self, flat_tensor, num_elem_list): diff --git a/deepspeed/runtime/comm/hccl.py b/deepspeed/runtime/comm/hccl.py new file mode 100644 index 000000000000..09fb11a731b8 --- /dev/null +++ b/deepspeed/runtime/comm/hccl.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy as np +import torch +import torch_npu +import deepspeed.comm as dist + + +class HcclBackend(object): + + def __init__(self, mpu=None): + if mpu is None: + self.world_group = dist.new_group(ranks=range(dist.get_world_size())) + else: + self.mpu = mpu + self.world_group = self.mpu.get_data_parallel_group() + self.size = dist.get_world_size(group=self.world_group) + self.rank = dist.get_rank(group=self.world_group) + + def my_igather(self, rank, size, group, sendbuf, recvbuf, root): + req = [] + if rank == root: + for idx in range(size): + if idx != rank: + req.append(dist.irecv(recvbuf[idx], src=idx, group=group)) + else: + recvbuf[rank] = sendbuf + else: + req.append(dist.isend(sendbuf, group=group, dst=root)) + return req + + def my_gather(self, rank, size, group, sendbuf, recvbuf, root): + if rank == root: + for idx in range(size): + if idx != rank: + dist.recv(recvbuf[idx], src=idx, group=group) + else: + recvbuf[rank] = sendbuf + else: + dist.send(sendbuf, group=group, dst=root) + + def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank): + original_shape = buffer_m.size() + if len(original_shape) > 1: + buffer_m = torch.flatten(buffer_m) + + # align size of original_buffer and error + original_size = buffer_m.numel() + worker_error_size = worker_error.numel() + if original_size != worker_error_size: + empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device) + buffer_m = torch.cat([buffer_m, empty_tensor]) + + buffer_m.add_(worker_error) + worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) + + worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + sign_list_packed_tmp = torch_npu.npu_sign_bits_pack(buffer_m, self.size).type(torch.int8) + + recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])], + dtype=sign_list_packed_tmp[0].dtype, + device=sign_list_packed_tmp.device) + + sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)] + + recvbuf_scale = [ + torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(local_rank)) for _ in range(self.size) + ] + + # communication phase 1 + # all to all for sign + dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group) + # all gather for scale + dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group) + + flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten() + compensated_server_m = torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign, self.size, torch.float32) \ + .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) + + compensated_server_m.add_(server_error) + + server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) + + server_error.set_(compensated_server_m - + server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + server_sign_packed = torch_npu.npu_sign_bits_pack(compensated_server_m, 1).type(torch.int8) + + # recvbuf_sign_server + recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])], + dtype=recvbuf_sign.dtype, + device=server_sign_packed.device) + + recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)] + + # recvbuf_scale_server + recvbuf_scale_server_tmp = torch.zeros([self.size, 1], + dtype=worker_scale.dtype, + device=server_sign_packed.device) + + recvbuf_scale_server = 
[recvbuf_scale_server_tmp[idx] for idx in range(self.size)] + + # communication Phase 2 + dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group) + dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) + + recvbuf_sign_server = torch.stack(recvbuf_sign_server) + + flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten() + + buffer_m.data.copy_( + torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign_server, self.size, + torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data) + + if original_size != worker_error_size: + buffer_m = buffer_m[0:original_size] + if len(original_shape) > 1: + buffer_m = buffer_m.reshape(original_shape) + + return buffer_m diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 8f62f36f328e..44b44c79ba55 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1180,9 +1180,15 @@ def _do_optimizer_sanity_check(self, basic_optimizer): # data type checks elif model_dtype == grad_accum_dtype: if model_dtype == torch.bfloat16: - raise NotImplementedError( - "Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation" - ) + if self.pipeline_parallelism: + logger.warning( + "**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****" + ) + return BFLOAT16 + else: + raise NotImplementedError( + "Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation" + ) if model_dtype == torch.float16: return FP16 # else optimizer_wrapper = None @@ -1444,7 +1450,8 @@ def _configure_bf16_optimizer(self, optimizer): clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, - timers=timers) + timers=timers, + grad_acc_dtype=self.get_data_types()[1]) return optimizer diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py index 236eea8cadc5..ae3e5f573850 100644 --- a/deepspeed/runtime/fp16/onebit/adam.py +++ b/deepspeed/runtime/fp16/onebit/adam.py @@ -70,8 +70,6 @@ def __init__(self, super(OnebitAdam, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.comm_time = 0.0 self.step_time = 0.0 self.ave_step = 1 @@ -86,22 +84,23 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': assert ( required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." 
from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) - + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size self.divider = int(self.size * 8 / np.gcd(self.size, 8)) diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py index 0662fabeeee1..9cd2e0f25648 100644 --- a/deepspeed/runtime/fp16/onebit/lamb.py +++ b/deepspeed/runtime/fp16/onebit/lamb.py @@ -93,8 +93,6 @@ def __init__(self, super(OnebitLamb, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.deepspeed = deepspeed self.lamb_freeze_key = False self.initialize = False @@ -108,21 +106,23 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': assert ( required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size @@ -161,7 +161,7 @@ def step(self, closure=None, grads=None): else: grads_group = grads - #remove the previous stats + # remove the previous stats del self.lamb_coeffs[:] if self.lamb_freeze_key: diff --git a/deepspeed/runtime/fp16/onebit/zoadam.py b/deepspeed/runtime/fp16/onebit/zoadam.py index 922263ad6a76..9ef671e7e3b7 100644 --- a/deepspeed/runtime/fp16/onebit/zoadam.py +++ b/deepspeed/runtime/fp16/onebit/zoadam.py @@ -83,8 +83,6 @@ def __init__(self, super(ZeroOneAdam, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.deepspeed = deepspeed self.initialize = False self.cuda_aware = cuda_aware @@ -99,22 +97,23 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': assert ( required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. 
Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) - + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size self.divider = int(self.size * 8 / np.gcd(self.size, 8)) diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py index da6f7a9be54e..a991c4304563 100644 --- a/deepspeed/runtime/hybrid_engine.py +++ b/deepspeed/runtime/hybrid_engine.py @@ -385,14 +385,20 @@ def eval(self): self._total_latency = self._total_latency + latency self._iters = self._iters + 1 if not dist.is_initialized() or dist.get_rank() == 0: + if self._total_batch_size is not None: + cur_samples_p_sec = f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + avg_samples_p_sec = f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}' + else: + cur_samples_p_sec = '' + avg_samples_p_sec = '' others = latency - (self._generate_latency + self._training_latency) print(f'|E2E latency={(latency):.2f}s ' + \ f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) ' f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \ f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \ - f'|Others={others:.2f} ({(others / latency * 100):.2f}%)' - f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + \ - f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}') + f'|Others={others:.2f} ({(others / latency * 100):.2f}%)' + \ + cur_samples_p_sec + \ + avg_samples_p_sec) self._t_start = time.time() self._training_latency = 0 super().eval() diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 48ccdbc29bf6..bc7a782e590c 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -14,7 +14,6 @@ import psutil import gc from math import sqrt -from bisect import bisect_left from packaging import version as pkg_version import torch @@ -570,67 +569,43 @@ def partition_uniform(num_items, num_parts): return parts -def _lprobe(weights, num_parts, bottleneck): - num_items = len(weights) - total_weight = weights[-1] - - # initialize partitioning - parts = [0] * (num_parts + 1) - for p in range(1, num_parts + 1): - parts[p] = num_items - - bsum = bottleneck # running sum of target weight for pth partition - chunksize = num_items // num_parts - step = chunksize - for p in range(1, num_parts): - # Jump to the next bucket - while (step < num_items) and (weights[step] < bsum): - step += chunksize - - # Find the end index of partition p - parts[p] = bisect_left(weights, bsum, lo=step - chunksize, hi=min(step, num_items)) - # Nothing more to partition, return early - if parts[p] == num_items: - # See if the current partition is overweight. 
- part_size = weights[-1] - weights[parts[p - 1]] - return parts, part_size < bottleneck - - # Next partition target - bsum = weights[parts[p] - 1] + bottleneck - - return parts, bsum >= total_weight - - -def _rb_partition_balanced(weights, num_parts, eps): - total_weight = weights[-1] - lower = total_weight / num_parts # best case heaviest partition - upper = total_weight # worst case heaviest partition - - # Do a binary search for the best partitioning - while upper > lower + eps: - mid = lower + ((upper - lower) / 2) - parts, success = _lprobe(weights, num_parts, mid) - if success: - upper = mid - else: - lower = mid + eps - return upper - - -def partition_balanced(weights, num_parts, eps=1e-3): - num_items = len(weights) - # First check for the trivial edge case - if num_items <= num_parts: - return partition_uniform(num_items, num_parts) - - weights_ = prefix_sum_inc(weights) - - # Find the smallest bottleneck (weight of heaviest partition) - bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps) - - # Now compute that partitioning - parts, success = _lprobe(weights_, num_parts, bottleneck) - assert success +def partition_balanced(weights, num_parts): + """ + use dynamic programming to solve `The Linear Partition Problem`. + see https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM + """ + import numpy as np + n = len(weights) + m = num_parts + + if n <= m: + return partition_uniform(n, m) + + dp_max = np.full((n + 1, m + 1), np.inf) + dp_min = np.full((n + 1, m + 1), np.inf) + dp_cost = np.full((n + 1, m + 1), np.inf) + position = np.zeros((n + 1, m + 1), dtype=int) + prefix_sum = np.zeros((n + 1)) + prefix_sum[1:] = np.cumsum(weights) + + dp_max[0, 0] = 0 + dp_cost[0, 0] = 0 + for i in range(1, n + 1): + for j in range(1, min(i, m) + 1): + for k in range(i): + max_sum = max(dp_max[k, j - 1], prefix_sum[i] - prefix_sum[k]) + min_sum = min(dp_min[k, j - 1], prefix_sum[i] - prefix_sum[k]) + cost = max_sum - min_sum + if dp_cost[i, j] >= cost: + dp_cost[i, j] = cost + dp_max[i, j] = max_sum + dp_min[i, j] = min_sum + position[i, j] = k + + parts = [n] + for i in reversed(range(1, m + 1)): + parts.append(position[parts[-1], i]) + parts.reverse() return parts diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 63aa4ca4c48e..30a168dcd396 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1215,7 +1215,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: @instrument_w_nvtx def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]: dtype = buffer_to_reduce.dtype - if self.communication_data_type == self.dtype: + if self.communication_data_type != dtype: buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type) if self.postscale_gradients and self.gradient_predivide_factor != 1.0: buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor) @@ -1984,7 +1984,7 @@ def step(self, closure=None): # warn user about caching allocator flushes memory_stats = get_accelerator().memory_stats() alloc_retries = memory_stats.get("num_alloc_retries") - if alloc_retries == None: + if alloc_retries is None: alloc_retries = 0 if alloc_retries > self.n_caching_allocator_flushes: if dist.get_rank() == 0: @@ -2541,7 +2541,7 @@ def load_state_dict(self, # when use loading checkpoint serial, after finish loading, we need to # delete the temp state_dict_list variable to save memory, then trigger # the next rank's loading - if load_serial !=
None: + if load_serial is not None: load_serial += 1 rank = dist.get_rank(group=self.dp_process_group) local_rank = dist.get_local_rank() diff --git a/op_builder/cpu/comm.py b/op_builder/cpu/comm.py index ec908eb0622b..b26328341081 100644 --- a/op_builder/cpu/comm.py +++ b/op_builder/cpu/comm.py @@ -35,7 +35,7 @@ def is_compatible(self, verbose=True): def extra_ldflags(self): ccl_root_path = os.environ.get("CCL_ROOT") - if ccl_root_path == None: + if ccl_root_path is None: raise ValueError( "Didn't find CCL_ROOT, install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment variable" ) diff --git a/tests/unit/launcher/test_ds_arguments.py b/tests/unit/launcher/test_ds_arguments.py index a2d06e7601ab..ee6d4ce6b7be 100644 --- a/tests/unit/launcher/test_ds_arguments.py +++ b/tests/unit/launcher/test_ds_arguments.py @@ -40,7 +40,7 @@ def test_no_ds_arguments(): assert args.deepspeed == False assert hasattr(args, 'deepspeed_config') - assert args.deepspeed_config == None + assert args.deepspeed_config is None def test_no_ds_enable_argument(): @@ -74,7 +74,7 @@ def test_no_ds_config_argument(): assert args.deepspeed == True assert hasattr(args, 'deepspeed_config') - assert args.deepspeed_config == None + assert args.deepspeed_config is None def test_no_ds_parser(): diff --git a/tests/unit/runtime/zero/test_zero_config.py b/tests/unit/runtime/zero/test_zero_config.py index db9fd6516034..8b20eca8c7d2 100644 --- a/tests/unit/runtime/zero/test_zero_config.py +++ b/tests/unit/runtime/zero/test_zero_config.py @@ -48,12 +48,12 @@ def test_zero_config_overlapcomm(): def test_zero_config_offload_configs(): config = DeepSpeedZeroConfig() - assert config.offload_param == None - assert config.offload_optimizer == None + assert config.offload_param is None + assert config.offload_optimizer is None config = DeepSpeedZeroConfig(**{"offload_param": None, "offload_optimizer": None}) - assert config.offload_param == None - assert config.offload_optimizer == None + assert config.offload_param is None + assert config.offload_optimizer is None config = DeepSpeedZeroConfig(**{"offload_param": {}, "offload_optimizer": {}}) assert isinstance(config.offload_param, DeepSpeedZeroOffloadParamConfig) diff --git a/tests/unit/utils/test_partition_balanced.py b/tests/unit/utils/test_partition_balanced.py new file mode 100644 index 000000000000..e7285e478c53 --- /dev/null +++ b/tests/unit/utils/test_partition_balanced.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime import utils as ds_utils + + +def check_partition(weights, num_parts, target_diff): + result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts) + + parts_sum = [] + for b, e in zip(result[:-1], result[1:]): + parts_sum.append(sum(weights[b:e])) + + assert max(parts_sum) - min( + parts_sum + ) == target_diff, f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}" + + +def test_partition_balanced(): + check_partition([1, 2, 1], 4, target_diff=2) + check_partition([1, 1, 1, 1], 4, target_diff=0) + check_partition([1, 1, 1, 1, 1], 4, target_diff=1) + check_partition([1, 1, 1, 1, 0, 1], 4, target_diff=1)
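
For reference, the boundaries returned by the new dynamic-programming partition_balanced are consumed exactly as check_partition does above: the function returns num_parts + 1 indices into weights, and consecutive index pairs delimit each partition. Below is a minimal usage sketch; the only assumption is a DeepSpeed install that already includes this change, and the example weights are taken from test_partition_balanced above.

# Minimal usage sketch of the new dynamic-programming partition_balanced
# (assumes a DeepSpeed install that includes this change). The function
# returns num_parts + 1 boundary indices into weights; consecutive index
# pairs delimit each partition, mirroring check_partition above.
from deepspeed.runtime import utils as ds_utils

weights = [1, 1, 1, 1, 1]  # taken from test_partition_balanced above
parts = ds_utils.partition_balanced(weights=weights, num_parts=4)

# Convert boundary indices into per-partition weight sums.
part_sums = [sum(weights[b:e]) for b, e in zip(parts[:-1], parts[1:])]
assert max(part_sums) - min(part_sums) == 1  # the same spread the test asserts

Note that the new signature drops the eps argument of the binary-search implementation it replaces, so any external caller that passed eps explicitly would need a small update.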