diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index 2d5332578edc..19dbe73726f7 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -452,14 +452,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                                            unsigned layer_id,
                                            unsigned num_layers,
                                            at::Tensor& alibi,
-                                           float rope_theta)
+                                           float rope_theta,
+                                           bool is_prompt,
+                                           std::optional<at::Tensor> token_idx,
+                                           std::optional<at::Tensor> position_ids)
 {
     unsigned bsz = query_key_value.size(0);
     unsigned seq_len = query_key_value.size(1);
     int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
     unsigned hidden_dim = heads * k;
 
-    bool is_prompt = (seq_len > 1);
+    is_prompt = (seq_len > 1);
 
     if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
     unsigned soft_len = InferenceContext::Instance().current_tokens();
@@ -2031,7 +2034,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
         "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \
     m.def("dequantize_" #_name,                                                  \
           &ds_dequantize<_dtype>,                                                \
-          "DeepSpeed dequantize with " #_name " (CUDA)")
+          "DeepSpeed dequantize with " #_name " (CUDA)");
 
 DEF_OPS(fp32, float);
 DEF_OPS(fp16, __half);
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 68836ceb523c..6574d49fb132 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -53,12 +53,7 @@ def __init__(self, model, config):
         DS_INFERENCE_ENABLED = True
 
         super().__init__()
-
-        # Have to import here because inference_module is a global, but python
-        # globals only work at the module level and will not be updated unless
-        # we import it each time we init a new inference engine.
-        from ..model_implementations.transformers.ds_transformer import inference_module
-        if inference_module is not None:
+        if DeepSpeedTransformerInference.workspace is not None:
             self.destroy()
 
         self.module = model
@@ -191,15 +186,11 @@ def __init__(self, model, config):
         self._is_compiled = False
 
     def destroy(self):
-        # Have to import here because inference_module is a global, but python
-        # globals only work at the module level and will not be updated unless
-        # we import it each time we init a new inference engine.
-        from ..model_implementations.transformers.ds_transformer import inference_module
         DeepSpeedTransformerInference.layer_id = 0
         DeepSpeedSelfAttention.num_layers = 0
-        if inference_module is not None:
-            inference_module.release_workspace()
-            inference_module = None
+        if DeepSpeedTransformerInference.workspace.is_allocated():
+            DeepSpeedTransformerInference.workspace.release_workspace()
+            DeepSpeedTransformerInference.workspace = None
 
     def profile_model_time(self, use_cuda_events=True):
         if not self.model_profile_enabled and not self._config.enable_cuda_graph:
diff --git a/deepspeed/model_implementations/transformers/ds_llama2.py b/deepspeed/model_implementations/transformers/ds_llama2.py
index 7d9eb4113a8a..325bfb4f7e18 100644
--- a/deepspeed/model_implementations/transformers/ds_llama2.py
+++ b/deepspeed/model_implementations/transformers/ds_llama2.py
@@ -4,11 +4,8 @@
 # DeepSpeed Team
 
 import torch
-from deepspeed import comm as dist
 from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
 
-inference_module = None
-
 
 class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
     """Initialize the DeepSpeed OPT Transformer Layer.
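For reference, the new is_prompt, token_idx and position_ids arguments of the C++ binding are supplied from the Python side rather than being inferred inside the kernel. A minimal sketch of how a caller can derive them, mirroring the kwargs handling added to ds_attention.py further down; the helper name here is illustrative and not part of the patch:

import torch

def derive_generation_state(qkv_out: torch.Tensor, **kwargs):
    # Prompt processing vs. single-token decode; "first_token" overrides the
    # shape-based heuristic that the binding previously computed itself.
    is_prompt = kwargs.get("first_token", qkv_out.shape[1] > 1)
    token_idx = kwargs.get("token_idx", None)          # current write position in a static-shape KV cache
    position_ids = kwargs.get("position_ids", None)    # positions for rotary embeddings
    return is_prompt, token_idx, position_ids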
@@ -27,18 +24,10 @@ def forward(self, *args, **kwargs): input = args[0] input_mask = None - # Allocate memory only on first layer forward - if self.config.layer_id == 0 and self._alloc_workspace: - self.allocate_workspace(self.config.hidden_size, self.config.heads, - input.size()[1], - input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size, - self.config.bigscience_bloom, - dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens, - self.config.min_out_tokens) - self._alloc_workspace = False - get_present = True + self.allocate_workspace(input.size()) + # We set the prev key/value to None when there is a prompt if input.shape[1] > 1: self.layer_past = None diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index d87d0de997b5..360113b78a3d 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -6,19 +6,18 @@ import torch import torch.nn as nn from deepspeed import comm as dist +from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp from deepspeed.utils.logging import log_dist from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention +from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import InferenceBuilder import deepspeed if deepspeed.HAS_TRITON: from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention -inference_module = None - class DeepSpeedTransformerInference(nn.Module): """Initialize the DeepSpeed Transformer Layer. @@ -37,6 +36,7 @@ class DeepSpeedTransformerInference(nn.Module): for specific downstream tasks. 
""" layer_id = 0 + workspace = None def __init__(self, config, @@ -52,10 +52,6 @@ def __init__(self, DeepSpeedTransformerInference.layer_id += 1 data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype - global inference_module - if inference_module is None: - builder = InferenceBuilder() - inference_module = builder.load() if DeepSpeedTransformerInference.layer_id == 1: log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0]) @@ -88,22 +84,25 @@ def __init__(self, self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), requires_grad=False) self.layer_past = None - try: - if config.dtype == torch.float32: - self.allocate_workspace = inference_module.allocate_workspace_fp32 - elif config.dtype == torch.bfloat16: - self.allocate_workspace = inference_module.allocate_workspace_bf16 - else: - self.allocate_workspace = inference_module.allocate_workspace_fp32 - self._alloc_workspace = True - except AttributeError: - self.allocate_workspace = None - self._alloc_workspace = False + self.layer_norm = LayerNormOp() + if DeepSpeedTransformerInference.workspace is None: + DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config) + self._should_allocate_workspace = True + + def allocate_workspace(self, size): + # Allocate memory only on first layer forward + if self.config.layer_id == 0 and self._should_allocate_workspace: + DeepSpeedTransformerInference.workspace.allocate_workspace( + self.config.hidden_size, self.config.heads, size[1], size[0], DeepSpeedTransformerInference.layer_id, + self.config.mp_size, self.config.bigscience_bloom, + dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens, + self.config.min_out_tokens) + self._should_allocate_workspace = False @classmethod def reset_cache(cls): - if inference_module is not None: - inference_module.reset_cache() + if cls.workspace is not None: + cls.workspace.reset_cache() def forward( self, @@ -136,15 +135,7 @@ def forward( input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask - # Allocate memory only on first layer forward - if self.config.layer_id == 0 and self._alloc_workspace: - self.allocate_workspace(self.config.hidden_size, self.config.heads, - input.size()[1], - input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size, - self.config.bigscience_bloom, - dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens, - self.config.min_out_tokens) - self._alloc_workspace = False + self.allocate_workspace(input.size()) get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask @@ -178,14 +169,15 @@ def forward( output_attentions, self.norm_w, self.norm_b, - alibi) + alibi, + **kwargs) presents = (key, value) self.layer_past = presents if layer_past is None else None output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob) if not self.config.pre_layer_norm: - output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon) + output = self.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon) output = output.to(input_type) if get_present: diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py index 9709328cc133..c0dd29f4f962 100644 --- a/deepspeed/ops/transformer/inference/config.py +++ b/deepspeed/ops/transformer/inference/config.py @@ -103,7 +103,6 @@ def __init__(self, 
self.return_tuple = return_tuple self.mlp_after_attn = mlp_after_attn self.mlp_act_func_type = mlp_act_func_type - self.specialized_mode = False self.training_mp_size = training_mp_size self.bigscience_bloom = bigscience_bloom self.max_out_tokens = max_out_tokens diff --git a/deepspeed/ops/transformer/inference/diffusers_attention.py b/deepspeed/ops/transformer/inference/diffusers_attention.py index 5efc560db75e..3c2340ccfc6f 100644 --- a/deepspeed/ops/transformer/inference/diffusers_attention.py +++ b/deepspeed/ops/transformer/inference/diffusers_attention.py @@ -10,10 +10,11 @@ from packaging import version as pkg_version from deepspeed.utils.logging import log_dist from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp +from deepspeed.ops.transformer.inference.op_binding.softmax_context import SoftmaxContextOp +from deepspeed.ops.transformer.inference.op_binding import LinearOp +from deepspeed.ops.transformer.inference.op_binding.pad_transform import PadTransformOp -# Cuda modules will be imported if needed -inference_module = None minus_inf = -10000.0 triton_flash_attn = None @@ -36,7 +37,8 @@ class DeepSpeedDiffusersAttentionFunction(Function): @staticmethod def forward(ctx, input, context, input_mask, config, attn_qkvw, attn_qw, attn_kw, attn_vw, attn_qkvb, num_attention_heads_per_partition, norm_factor, hidden_size_per_partition, attn_ow, attn_ob, - do_out_bias, score_context_func, linear_func, triton_flash_attn_kernel, rope_theta): + do_out_bias, score_context_func, linear_func, pad_transform_func, triton_flash_attn_kernel, + rope_theta): def _transpose_for_context(x): x = x.permute(0, 2, 1, 3) @@ -77,7 +79,7 @@ def selfAttention_fp(input, context, input_mask): query = query.contiguous() key = key.contiguous() value = value.contiguous() - query, key, value = inference_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn) + query, key, value = pad_transform_func(query, key, value, config.heads, do_flash_attn) attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1) context_layer = _transpose_for_context(torch.matmul(attention_scores, value)) @@ -117,10 +119,6 @@ def __init__( data_type = self.config.dtype data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype - global inference_module - if inference_module is None: - builder = InferenceBuilder() - inference_module = builder.load() if DeepSpeedDiffusersAttention.layer_id == 1: log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0]) @@ -171,26 +169,24 @@ def __init__( self.norm_factor *= math.sqrt(self.config.layer_id + 1) # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191 - if self.config.dtype in [torch.float16, torch.int8]: - self.score_context_func = inference_module.softmax_context_fp16 - self.linear_func = inference_module.linear_layer_fp16 - self.allocate_workspace = inference_module.allocate_workspace_fp16 - else: - self.score_context_func = inference_module.softmax_context_fp32 - self.linear_func = inference_module.linear_layer_fp32 - self.allocate_workspace = inference_module.allocate_workspace_fp32 + self.workspace = WorkspaceOp(self.config) + self.score_context_func = SoftmaxContextOp(self.config) + self.linear_func = LinearOp(self.config) + self.pad_transform_func = PadTransformOp(self.config) - def forward(self, input, context=None, 
input_mask=None): + def allocate_workspace(self, size): + # Allocate memory only on first layer forward if self.config.layer_id == 0: - self.allocate_workspace(self.config.hidden_size, self.config.heads, - input.size()[1], - input.size()[0], DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, - 0, self.config.max_out_tokens, self.config.min_out_tokens) - output = DeepSpeedDiffusersAttentionFunction.apply(input, context, input_mask, self.config, self.attn_qkvw, - self.attn_qw, self.attn_kw, self.attn_vw, self.attn_qkvb, - self.num_attention_heads_per_partition, self.norm_factor, - self.hidden_size_per_partition, self.attn_ow, self.attn_ob, - self.do_out_bias, self.score_context_func, self.linear_func, - self.triton_flash_attn_kernel, self.config.rope_theta) + self.workspace.allocate_workspace(self.config.hidden_size, self.config.heads, size[1], size[0], + DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, 0, + self.config.max_out_tokens, self.config.min_out_tokens) + + def forward(self, input, context=None, input_mask=None): + self.allocate_workspace(input.size()) + output = DeepSpeedDiffusersAttentionFunction.apply( + input, context, input_mask, self.config, self.attn_qkvw, self.attn_qw, self.attn_kw, self.attn_vw, + self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, self.hidden_size_per_partition, + self.attn_ow, self.attn_ob, self.do_out_bias, self.score_context_func, self.linear_func, + self.pad_transform_func, self.triton_flash_attn_kernel, self.config.rope_theta) return output diff --git a/deepspeed/ops/transformer/inference/diffusers_transformer_block.py b/deepspeed/ops/transformer/inference/diffusers_transformer_block.py index b0156f905a06..d01638f36e40 100644 --- a/deepspeed/ops/transformer/inference/diffusers_transformer_block.py +++ b/deepspeed/ops/transformer/inference/diffusers_transformer_block.py @@ -10,26 +10,9 @@ from .diffusers_attention import DeepSpeedDiffusersAttention from .bias_add import nhwc_bias_add from .diffusers_2d_transformer import Diffusers2DTransformerConfig -from deepspeed.ops.op_builder import InferenceBuilder, SpatialInferenceBuilder from deepspeed.utils.types import ActivationFuncType - -# Ops will be loaded on demand -transformer_cuda_module = None -spatial_cuda_module = None - - -def load_transformer_module(): - global transformer_cuda_module - if transformer_cuda_module is None: - transformer_cuda_module = InferenceBuilder().load() - return transformer_cuda_module - - -def load_spatial_module(): - global spatial_cuda_module - if spatial_cuda_module is None: - spatial_cuda_module = SpatialInferenceBuilder().load() - return spatial_cuda_module +from .op_binding.gated_activation import GatedActivationOp +from .op_binding.layer_norm import LayerNormOp class DeepSpeedDiffusersTransformerBlock(nn.Module): @@ -76,8 +59,8 @@ def __init__(self, equivalent_module: nn.Module, config: Diffusers2DTransformerC else: self.attn_2_bias = nn.Paramaeter(torch.zeros_like(self.norm3_g), requires_grad=False) - self.transformer_cuda_module = load_transformer_module() - load_spatial_module() + self.gated_activation = GatedActivationOp() + self.layer_norm = LayerNormOp() def forward(self, hidden_states, context=None, timestep=None, **kwargs): # In v0.12.0 of diffuser, several new kwargs were added. 
Capturing @@ -88,17 +71,17 @@ def forward(self, hidden_states, context=None, timestep=None, **kwargs): if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] is not None: context = kwargs["encoder_hidden_states"] - out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps) + out_norm_1 = self.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps) out_attn_1 = self.attn_1(out_norm_1) - out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res( + out_norm_2, out_attn_1 = self.layer_norm.layer_norm_residual_store_pre_ln_res( out_attn_1, self.attn_1_bias, hidden_states, self.norm2_g, self.norm2_b, self.norm2_eps) out_attn_2 = self.attn_2(out_norm_2, context=context) - out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res( + out_norm_3, out_attn_2 = self.layer_norm.layer_norm_residual_store_pre_ln_res( out_attn_2, self.attn_2_bias, out_attn_1, self.norm3_g, self.norm3_b, self.norm3_eps) out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w) - out_geglu = self.transformer_cuda_module.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU) + out_geglu = self.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU) out_ff2 = nn.functional.linear(out_geglu, self.ff2_w) return nhwc_bias_add(out_ff2, self.ff2_b, other=out_attn_2) diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py index ffb58175daad..24f710d22494 100644 --- a/deepspeed/ops/transformer/inference/ds_attention.py +++ b/deepspeed/ops/transformer/inference/ds_attention.py @@ -89,7 +89,7 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device) ] - def compute_attention(self, qkv_out, input_mask, layer_past, alibi): + def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids): if isinstance(qkv_out, list) or isinstance(qkv_out, tuple): qkv_out = qkv_out[0] @@ -108,7 +108,10 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi): no_masking=no_masking, layer_id=self.config.layer_id, num_layers=DeepSpeedSelfAttention.num_layers, - alibi=alibi) + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) context_layer, key_layer, value_layer = attn_key_value return context_layer, key_layer, value_layer @@ -136,7 +139,8 @@ def forward(self, output_attentions=False, norm_w=None, norm_b=None, - alibi=None): + alibi=None, + **kwargs): if self.attn_qkvw is None: self._attn_qkvw, self._attn_qkvb = self._merge_qkv() else: @@ -157,10 +161,17 @@ def forward(self, gamma=norm_w, beta=norm_b) + is_prompt = kwargs.get("first_token", qkv_out[0].shape[1] > 1) + token_idx = kwargs.get("token_idx", None) + position_ids = kwargs.get("position_ids", None) + context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out, input_mask=input_mask, layer_past=layer_past, - alibi=alibi) + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) inp_norm = qkv_out[-1] @@ -210,7 +221,7 @@ def _split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_ return tensor_list - def compute_attention(self, qkv_out, input_mask, layer_past, alibi): + def compute_attention(self, qkv_out, input_mask, 
layer_past, alibi, is_prompt, token_idx, position_ids): if isinstance(qkv_out, list) or isinstance(qkv_out, tuple): qkv_out = qkv_out[0] diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py index fc001a86d42e..3a9785985d19 100644 --- a/deepspeed/ops/transformer/inference/moe_inference.py +++ b/deepspeed/ops/transformer/inference/moe_inference.py @@ -7,16 +7,16 @@ import math import torch from torch.autograd import Function -# accelerator modules will be imported if needed -inference_module = None -specialized_mode = None import torch.nn as nn from .ds_attention import DeepSpeedSelfAttention from .config import DeepSpeedInferenceConfig +from .op_binding import SoftmaxOp, VectorMatMulOp, GELUGemmOp +from .op_binding.bias_residual import BiasResidualOp +from .op_binding.einsum_sec_sm_ecm import EinsumSecSmEcmOp +from .op_binding.layer_norm import LayerNormOp from ....moe.sharded_moe import TopKGate from deepspeed import comm as dist -from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import InferenceBuilder +from .op_binding.moe_res_matmul import MoEResMatmulOp class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): @@ -110,16 +110,13 @@ class DeepSpeedMLPFunction(Function): @staticmethod def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups, merge_count, mp_group, - async_op): + async_op, gelu_gemm_func, vector_matmul_func): if config.q_int8: - intermediate = inference_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon, q_scales[2], - (q_groups * (2**merge_count)), config.pre_layer_norm) - output = inference_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups, (merge_count)) + intermediate = gelu_gemm_func(input, inter_w, inter_b, config.epsilon, q_scales[2], + (q_groups * (2**merge_count)), config.pre_layer_norm) + output = vector_matmul_func(intermediate, output_w, q_scales[3], q_groups, (merge_count)) else: - mlp_gemm_func = inference_module.fused_gemm_gelu_fp16 if config.fp16 else \ - inference_module.fused_gemm_gelu_fp32 - - output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op) + output = gelu_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op) if mp_group is not None and dist.get_world_size(group=mp_group) > 1: dist.all_reduce(output, group=mp_group, async_op=async_op) @@ -150,10 +147,13 @@ def __init__(self, config, q_scales=None, q_groups=1, merge_count=1, mlp_extra_g self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups self.merge_count = int(math.log2(merge_count)) self.mp_group = mp_group + self.gelu_gemm_func = GELUGemmOp(self.config) + self.vector_matmul_func = VectorMatMulOp(self.config) def forward(self, input, async_op=False): return DeepSpeedMLPFunction.apply(input, self.inter_w, self.inter_b, self.config, self.output_b, self.output_w, - self.q_scales, self.q_groups, self.merge_count, self.mp_group, async_op) + self.q_scales, self.q_groups, self.merge_count, self.mp_group, async_op, + self.gelu_gemm_func, self.vector_matmul_func) class DeepSpeedMoEInference(nn.Module): @@ -187,18 +187,7 @@ def __init__(self, self.config = config self.config.layer_id = DeepSpeedMoEInference.layer_id - global inference_module - global specialized_mode - if inference_module is None: - specialized_mode = False - # InferenceSpecializedBuilder is not among DeepSpeed provided builder yet, so we infer by builder name string 
- builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder") - if builder is not None and builder.is_compatible(): - inference_module = builder.load() - specialized_mode = True - else: - inference_module = InferenceBuilder().load() - self.config.specialized_mode = specialized_mode + assert self.config.dtype != torch.bfloat16, "DeepSpeed MoE Transformer Inference not yet tested for bfloat support" DeepSpeedMoEInference.layer_id += 1 @@ -213,10 +202,8 @@ def __init__(self, self.res_mlp = DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping, mp_group) self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2)) - self.coef_func = inference_module.softmax_fp16 if self.config.dtype in [torch.float16, torch.int8] else \ - inference_module.softmax_fp32 - self.vector_matmul_func = inference_module.vector_matmul_fp16 if self.config.dtype == torch.float16 else \ - inference_module.vector_matmul_fp32 + self.coef_func = SoftmaxOp(self.config) + self.vector_matmul_func = VectorMatMulOp(self.config) config.mp_size = 1 self.mlp = nn.ModuleList( @@ -234,12 +221,10 @@ def __init__(self, print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__) - self.bias_residual_func = inference_module.bias_residual_fp16 if self.config.dtype in [torch.float16, torch.int8] else \ - inference_module.bias_residual_fp32 - self.ds_layernorm = inference_module.layer_norm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \ - inference_module.layer_norm_fp32 - self.einsum_sec_sm_ecm = inference_module.einsum_sec_sm_ecm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \ - inference_module.einsum_sec_sm_ecm_fp32 + self.bias_residual_func = BiasResidualOp(self.config) + self.ds_layernorm = LayerNormOp(self.config) + self.einsum_sec_sm_ecm = EinsumSecSmEcmOp(self.config) + self.moe_res_matmul = MoEResMatmulOp(self.config) def res_coef_func(self, inp, async_op): inp = self.vector_matmul_func(inp, self.res_coef, async_op) @@ -346,7 +331,7 @@ def forward(self, dim=0)[dist.get_rank(group=self.expert_mp_group)] if self.config.mlp_type == 'residual': - inference_module.moe_res_matmul(res_mlp_out, res_coef_out, output) + self.moe_res_matmul(res_mlp_out, res_coef_out, output) output = self.bias_residual_func(output, residual_add, torch.empty(1)) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_add.py b/deepspeed/ops/transformer/inference/op_binding/bias_add.py new file mode 100644 index 000000000000..d2ae38f546eb --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_add.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasAddOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasAddOp, self).__init__(config) + + try: + if self.config.dtype == torch.float16: + self.bias_add_func = self.inference_module.bias_add_fp16 + elif self.config.dtype == torch.bfloat16: + self.bias_add_func = self.inference_module.bias_add_bf16 + else: + self.bias_add_func = self.inference_module.bias_add_fp32 + except AttributeError: + self.bias_add_func = self.bias_add_fallback + + @classmethod + def bias_add_fallback(cls, input, bias): + return torch.add(input, bias) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor): + return self.bias_add_func(activation, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_gelu.py b/deepspeed/ops/transformer/inference/op_binding/bias_gelu.py new file mode 100644 index 000000000000..f0fee0b0d06e --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_gelu.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasGeluOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasGeluOp, self).__init__(config) + + try: + if self.config.dtype == torch.float16: + self.bias_gelu_func = self.inference_module.bias_gelu_fp16 + elif self.config.dtype == torch.bfloat16: + self.bias_gelu_func = self.inference_module.bias_gelu_bf16 + else: + self.bias_gelu_func = self.inference_module.bias_gelu_fp32 + except AttributeError: + self.bias_gelu_func = self.bias_gelu_fallback + + @classmethod + def bias_gelu_fallback(cls, activations, bias): + # Expected behavior is that of casting to float32 internally and using the tanh approximation + return F.gelu(activations.to(torch.float32) + bias.to(torch.float32), approximate='tanh').to(activations.dtype) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor): + return self.bias_gelu_func(activation, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_relu.py b/deepspeed/ops/transformer/inference/op_binding/bias_relu.py new file mode 100644 index 000000000000..ccfade1d9524 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_relu.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. 
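The BiasGeluOp fallback above reproduces the fused kernel in plain PyTorch: add the bias and apply tanh-approximated GELU in float32, then cast back to the activation dtype. A standalone sketch of that math (illustrative, not part of the patch):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, dtype=torch.float16)
b = torch.randn(8, dtype=torch.float16)
# Compute in float32 with the tanh approximation, then cast back to the input dtype.
out = F.gelu(x.to(torch.float32) + b.to(torch.float32), approximate='tanh').to(x.dtype)
assert out.dtype == torch.float16 and out.shape == x.shape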
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasReluOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasReluOp, self).__init__(config) + + try: + if self.config.dtype == torch.float16: + self.bias_relu_func = self.inference_module.bias_relu_fp16 + elif self.config.dtype == torch.bfloat16: + self.bias_relu_func = self.inference_module.bias_relu_bf16 + else: + self.bias_relu_func = self.inference_module.bias_relu_fp32 + except AttributeError: + self.bias_relu_func = self.bias_relu_fallback + + @classmethod + def bias_relu_fallback(cls, activations, bias): + # Expected behavior is that of casting to float32 internally + return F.relu(activations.to(torch.float32) + bias.to(torch.float32)).to(activations.dtype) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor): + return self.bias_relu_func(activation, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/bias_residual.py b/deepspeed/ops/transformer/inference/op_binding/bias_residual.py new file mode 100644 index 000000000000..ecad50e10ffe --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/bias_residual.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class BiasResidualOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(BiasResidualOp, self).__init__(config) + + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.bias_residual_func = self.inference_module.bias_residual_fp16 + else: + self.bias_residual_func = self.inference_module.bias_residual_fp32 + except AttributeError: + self.bias_residual_func = self.bias_residual_fallback + + @classmethod + def bias_residual_fallback(cls, output, residual, bias): + raise NotImplementedError("bias residual fallback isn't implemented") + + def forward(self, output, residual, bias): + return self.bias_residual_func(output, residual, bias) diff --git a/deepspeed/ops/transformer/inference/op_binding/einsum_sec_sm_ecm.py b/deepspeed/ops/transformer/inference/op_binding/einsum_sec_sm_ecm.py new file mode 100644 index 000000000000..f34b10f786d1 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/einsum_sec_sm_ecm.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class EinsumSecSmEcmOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig): + super(EinsumSecSmEcmOp, self).__init__(config) + + try: + if self.config.dtype in [torch.float16, torch.int8]: + self.einsum_sec_sm_ecm_func = self.inference_module.einsum_sec_sm_ecm_fp16 + else: + self.einsum_sec_sm_ecm_func = self.inference_module.einsum_sec_sm_ecm_fp32 + except AttributeError: + self.einsum_sec_sm_ecm_func = self.einsum_sec_sm_ecm_fallback + + @classmethod + def einsum_sec_sm_ecm_fallback(cls, Q, W): + raise NotImplementedError("einsum sec sm ecm fallback isn't implemented") + + def forward(self, Q, W): + return self.einsum_sec_sm_ecm_func(Q, W) diff --git a/deepspeed/ops/transformer/inference/op_binding/gated_activation.py b/deepspeed/ops/transformer/inference/op_binding/gated_activation.py new file mode 100644 index 000000000000..d28d818ce4b3 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/gated_activation.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from deepspeed.utils.types import ActivationFuncType +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class GatedActivationOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(GatedActivationOp, self).__init__(config) + try: + self.gated_activation_func = self.inference_module.gated_activation + except AttributeError: + self.gated_activation_func = self.gated_activation_fallback + + @classmethod + def gated_activation_fallback(cls, activation, bias, activation_func_type): + # Expected behavior is that of casting to float32 internally + # Explicitly using the default GeLU + activation_func = None + activations = activation + bias.reshape(1, 1, -1) + hidden_states, gate = activations.chunk(2, dim=-1) + + if activation_func_type == ActivationFuncType.GATED_SILU: + activation_func = F.silu + elif activation_func_type == ActivationFuncType.GATED_GELU: + activation_func = F.gelu + + return hidden_states * activation_func(gate.to(torch.float32)).to(activations.dtype) + + def forward(self, activation: torch.Tensor, bias: torch.Tensor, activation_func_type: ActivationFuncType): + return self.gated_activation_func(activation, bias, activation_func_type) diff --git a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py index 63323c150752..60bbb4b48bdb 100644 --- a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp import deepspeed @@ -14,7 +15,9 @@ class GELUGemmOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(GELUGemmOp, self).__init__(config) try: - if self.config.dtype in [torch.float16, torch.int8]: + if self.config.dtype == torch.int8: + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_int8 + elif self.config.dtype == torch.float16: if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as 
_triton_fused_gemm_gelu self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore @@ -28,7 +31,11 @@ def __init__(self, config: DeepSpeedInferenceConfig): self.fused_gemm_gelu = self.gelu_gemm_fallback def gelu_gemm_fallback(self, input, weight, scale, bias, out, out_scale, dtype, transpose): - raise NotImplementedError + tmp = torch.matmul(input, weight) + tmp = F.gelu(tmp.to(torch.float32) + bias.to(torch.float32), approximate="tanh").to(tmp.dtype) + output = torch.matmul(tmp, out) + + return output def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, weight_out: torch.Tensor): diff --git a/deepspeed/ops/transformer/inference/op_binding/layer_norm.py b/deepspeed/ops/transformer/inference/op_binding/layer_norm.py new file mode 100644 index 000000000000..31219a58ac3c --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/layer_norm.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class LayerNormOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + super(LayerNormOp, self).__init__(config) + try: + if config is None: + self.layer_norm_func = self.inference_module.layer_norm + elif self.config.dtype in [torch.float16, torch.int8]: + self.layer_norm_func = self.inference_module.layer_norm_fp16 + else: + self.layer_norm_func = self.inference_module.layer_norm_fp32 + except AttributeError: + self.layer_norm_func = self.layer_norm_fallback + + @classmethod + def layer_norm_residual(cls, vals, bias, res, gamma, beta, epsilon): + channels = gamma.shape[0] + dtype = gamma.dtype + vals_f = vals.to(torch.float32) + bias_f = bias.to(torch.float32).reshape(1, 1, -1) + res_f = res.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return F.layer_norm(vals_f + bias_f + res_f, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) + + @classmethod + def layer_norm_residual_store_pre_ln_res(cls, vals, bias, res, gamma, beta, epsilon): + channels = gamma.shape[0] + dtype = gamma.dtype + vals_f = vals.to(torch.float32) + bias_f = bias.to(torch.float32).reshape(1, 1, -1) + res_f = res.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + res_output = vals_f + bias_f + res_f + norm_output = F.layer_norm(res_output, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) + return norm_output, res_output.to(dtype) + + @classmethod + def layer_norm_fallback(cls, vals, gamma, beta, epsilon): + channels = gamma.shape[0] + dtype = gamma.dtype + vals_f = vals.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return F.layer_norm(vals_f, (channels, ), weight=gamma_f, bias=beta_f, eps=epsilon).to(dtype) + + def forward(self, vals, gamma, beta, epsilon): + return self.layer_norm_func(vals, gamma, beta, epsilon) diff --git a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py index 3064c00d1755..5f1f915ec021 100644 --- a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py @@ -5,12 +5,12 @@ from typing import Optional -import os import torch import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp from deepspeed.utils.types import NormType +from 
.pre_rms_norm import PreRMSNormOp class MLPGemmOp(BaseOp): @@ -39,23 +39,45 @@ def __init__(self, config: DeepSpeedInferenceConfig): self.mlp_gemm_func = self.mlp_gemm_fallback elif self.config.norm_type == NormType.RMSNorm: self.mlp_gemm_func = self.rms_mlp_gemm_fallback + self.pre_rms_norm = PreRMSNormOp() def mlp_gemm_fallback(self, input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps, pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type, transpose): - if os.environ.get('DS_KI_FALLBACK') == 'True' and mlp_after_attn and not transpose: - residual_add = F.layer_norm(input + residual + input_bias, (input.shape[2], ), gamma, beta, - self.config.epsilon) - tmp = torch.matmul(residual_add, weight_interm) + if mlp_after_attn: + residual_add = F.layer_norm(input + residual + input_bias, (input.shape[2], ), gamma, beta, eps) + tmp = torch.matmul(residual_add, weight_interm.t() if transpose else weight_interm) tmp = F.gelu(tmp + bias) - output = torch.matmul(tmp, weight_out) - return (output, residual_add) + output = torch.matmul(tmp, weight_out.t() if transpose else weight_out) + + return output, residual_add else: raise NotImplementedError def rms_mlp_gemm_fallback(self, input, residual, weight_interm, weight_out, gamma, eps, interm_scale, out_scale, dtype, mlp_act_func_type, transpose): - raise NotImplementedError + inp_norm, residual = self.pre_rms_norm(input, residual, gamma, eps) + tmp = torch.matmul(inp_norm.view([-1, inp_norm.size(2)]), weight_interm.t() if transpose else weight_interm) + up_proj, gate_proj = tmp.chunk(2, dim=1) + + from deepspeed.utils.types import ActivationFuncType + if mlp_act_func_type == ActivationFuncType.GELU: + intermediate = F.gelu(gate_proj) + elif mlp_act_func_type == ActivationFuncType.ReLU: + intermediate = F.relu(gate_proj) + elif mlp_act_func_type == ActivationFuncType.GATED_GELU: + intermediate = F.gelu(gate_proj) + elif mlp_act_func_type == ActivationFuncType.GATED_SILU: + intermediate = F.silu(gate_proj) + else: + raise f"rms_mlp_gemm_fallback not implemented for activation type {mlp_act_func_type}" + + intermediate = intermediate * up_proj + + output = torch.matmul(intermediate, weight_out.t() if transpose else weight_out) + output = output.view([input.size(0), input.size(1), -1]) + + return [output, residual] def forward(self, input: torch.Tensor, diff --git a/deepspeed/ops/transformer/inference/op_binding/moe_res_matmul.py b/deepspeed/ops/transformer/inference/op_binding/moe_res_matmul.py new file mode 100644 index 000000000000..ef3558c8bc88 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/moe_res_matmul.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. 
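The new rms_mlp_gemm_fallback above implements a LLaMA-style gated MLP: the fused intermediate weight produces up/gate halves, the gate half goes through the activation, and the elementwise product feeds the output projection. A self-contained sketch of that flow with made-up sizes (illustrative only, GATED_SILU case):

import torch
import torch.nn.functional as F

hidden, inter = 16, 32
x = torch.randn(4, hidden)
w_in = torch.randn(hidden, 2 * inter)   # fused projection producing [up, gate]
w_out = torch.randn(inter, hidden)

up, gate = torch.matmul(x, w_in).chunk(2, dim=-1)
y = torch.matmul(F.silu(gate) * up, w_out)
assert y.shape == (4, hidden)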
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class MoEResMatmulOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(MoEResMatmulOp, self).__init__(config) + try: + self.moe_res_matmul_func = self.inference_module.moe_res_matmul + except AttributeError: + self.moe_res_matmul_func = self.moe_res_matmul_fallback + + @classmethod + def moe_res_matmul_fallback(cls, residual, coef, output): + coef_t = coef.transpose(1, 2).contiguous() + coef1, coef2 = torch.split(coef_t, split_size_or_sections=coef_t.shape[len(coef_t.shape) - 1] // 2, dim=-1) + return residual * coef1 + output * coef2 + + def forward(self, residual, coef, output): + return self.moe_res_matmul_func(residual, coef, output) diff --git a/deepspeed/ops/transformer/inference/op_binding/pad_transform.py b/deepspeed/ops/transformer/inference/op_binding/pad_transform.py new file mode 100644 index 000000000000..876fefc3bcfb --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/pad_transform.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class PadTransformOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(PadTransformOp, self).__init__(config) + try: + self.pad_transform_func = self.inference_module.pad_transform_fp16 + except AttributeError: + self.pad_transform_func = self.pad_transform_fallback + + @staticmethod + def pad_transform_fallback(query, key, value, heads, do_flash_attn): + raise NotImplementedError("pad_transform fallback is not implemented.") + + def forward(self, query, key, value, heads, do_flash_attn): + return self.pad_transform_func(query, key, value, heads, do_flash_attn) diff --git a/deepspeed/ops/transformer/inference/op_binding/pre_rms_norm.py b/deepspeed/ops/transformer/inference/op_binding/pre_rms_norm.py new file mode 100644 index 000000000000..7969d20f0527 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/pre_rms_norm.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp +from .rms_norm import RMSNormOp + + +class PreRMSNormOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(PreRMSNormOp, self).__init__(config) + try: + self.pre_rms_norm_func = self.inference_module.pre_rms_norm + except AttributeError: + self.pre_rms_norm_func = self.pre_rms_norm_fallback + + @staticmethod + def pre_rms_norm_fallback(vals, residual, gamma, epsilon): + residual = vals.to(torch.float32) + residual.to(torch.float32) + vals = residual + + return RMSNormOp.rms_norm_fallback(vals, gamma, epsilon), residual.to(gamma.dtype) + + def forward(self, vals, residual, gamma, epsilon): + return self.pre_rms_norm_func(vals, residual, gamma, epsilon) diff --git a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py index 250bf9864e1e..9ff5366fae5d 100644 --- a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py @@ -3,11 +3,11 @@ # DeepSpeed Team -import os import torch import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp +from .rms_norm import RMSNormOp import deepspeed from deepspeed.utils.types import NormType @@ -56,19 +56,23 @@ def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): matmul(A, B) Fp16Matmul._update_autotune_table() - def qkv_gemm_fallback(self, input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): - if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose: - inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps) - tmp = torch.matmul(inp_norm, weight) - if add_bias: - tmp += bias - output = [tmp, inp_norm] - return output - else: - raise NotImplementedError + @staticmethod + def qkv_gemm_fallback(input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): + inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps) + tmp = torch.matmul(inp_norm, weight.t() if transpose else weight) + if add_bias: + tmp += bias + output = [tmp, inp_norm] + + return output + + @staticmethod + def rms_qkv_gemm_fallback(input, weight, q_scale, gamma, eps, q_int8, transpose): + inp_norm = RMSNormOp.rms_norm_fallback(input, gamma, eps) + tmp = torch.matmul(inp_norm, weight.t() if transpose else weight) + output = [tmp, inp_norm] - def rms_qkv_gemm_fallback(self, input, weight, q_scale, gamma, eps, q_int8, transpose): - raise NotImplementedError + return output def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor): diff --git a/deepspeed/ops/transformer/inference/op_binding/residual_add.py b/deepspeed/ops/transformer/inference/op_binding/residual_add.py index 6f9b35cbc05d..93b229c5d1ac 100644 --- a/deepspeed/ops/transformer/inference/op_binding/residual_add.py +++ b/deepspeed/ops/transformer/inference/op_binding/residual_add.py @@ -3,9 +3,10 @@ # DeepSpeed Team -import os import torch from typing import Optional + +from .vector_add import VectorAddOp from ..config import DeepSpeedInferenceConfig from .base import BaseOp @@ -22,11 +23,32 @@ def __init__(self, config: DeepSpeedInferenceConfig): else: self.residual_add_func = self.inference_module.residual_add_bias_fp32 except AttributeError: - 
self.residual_add_func = None - try: - self._vector_add = self.inference_module._vector_add - except AttributeError: - self._vector_add = None + self.residual_add_func = self.residual_add_fallback + self.vector_add = VectorAddOp() + + @staticmethod + def res_add_bias(hidden_state, residual, attn_output, attn_bias, final_bias, add_attn_bias, mp_size): + hidden_state += attn_output + (residual + final_bias) / mp_size + if add_attn_bias: + hidden_state += attn_bias / mp_size + + return hidden_state + + @staticmethod + def residual_add_fallback(hidden_state, residual, attention_output, attention_bias, final_bias, mp_size, + mlp_after_attn, add_bias, pre_layer_norm): + if mlp_after_attn: + if pre_layer_norm: + tmp = (residual.float() + attention_output.float() + attention_bias.float() + + final_bias.float()) / mp_size + hidden_state.float() + else: + tmp = residual.float() + hidden_state.float() + final_bias.float() + else: + tmp = ResidualAddOp.res_add_bias(hidden_state, residual, attention_output, attention_bias, final_bias, + add_bias, mp_size) + residual.copy_(tmp.to(hidden_state.dtype)) + + return residual def forward(self, hidden_state: torch.Tensor, @@ -37,28 +59,15 @@ def forward(self, attention_bias: Optional[torch.Tensor] = None, final_bias: Optional[torch.Tensor] = None): - if self.residual_add_func is not None: - if final_bias is None: - residual = self._vector_add(residual, hidden_state, 1.0 / self.config.mp_size) - else: - if not self.config.pre_layer_norm and residual_add is not None: - # only use residual add if its set and we are not pre layer norm - residual = residual_add - - self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias, - self.config.mp_size, self.config.mlp_after_attn, add_bias, - self.config.pre_layer_norm) + if final_bias is None and attention_bias is None: + residual = self.vector_add(residual + attention_output, hidden_state, 1.0 / self.config.mp_size) else: - # fallback - if os.environ.get('DS_KI_FALLBACK') == 'True' and self.config.mlp_after_attn: - if self.config.pre_layer_norm: - tmp = (residual.float() + attention_output.float() + attention_bias.float() + - final_bias.float()) / self.config.mp_size + hidden_state.float() - else: - tmp = residual.float() + hidden_state.float() + final_bias.float() + if not self.config.pre_layer_norm and residual_add is not None: + # only use residual add if its set and we are not pre layer norm + residual = residual_add + + self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias, + self.config.mp_size, self.config.mlp_after_attn, add_bias, + self.config.pre_layer_norm) - input_dtype = hidden_state.dtype - residual = tmp.to(input_dtype) - else: - raise NotImplementedError return residual diff --git a/deepspeed/ops/transformer/inference/op_binding/rms_norm.py b/deepspeed/ops/transformer/inference/op_binding/rms_norm.py new file mode 100644 index 000000000000..128883ce5d43 --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/rms_norm.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. 
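The pre-layer-norm branch of the new residual_add_fallback averages the attention output and both biases over mp_size before adding the MLP output held in hidden_state. A small numeric check of that arithmetic (illustrative, not part of the patch):

import torch

mp_size = 2
hidden_state = torch.full((1, 1, 4), 1.0)
residual = torch.full((1, 1, 4), 2.0)
attention_output = torch.full((1, 1, 4), 3.0)
attention_bias = torch.full((1, 1, 4), 4.0)
final_bias = torch.full((1, 1, 4), 5.0)

tmp = (residual + attention_output + attention_bias + final_bias) / mp_size + hidden_state
assert torch.allclose(tmp, torch.full((1, 1, 4), 8.0))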
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class RMSNormOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(RMSNormOp, self).__init__(config) + try: + self.rms_norm_func = self.inference_module.rms_norm + except AttributeError: + self.rms_norm_func = self.rms_norm_fallback + + @staticmethod + def rms_norm_fallback(vals, gamma, epsilon): + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals + + def forward(self, vals, gamma, epsilon): + return self.rms_norm_func(vals, gamma, epsilon) diff --git a/deepspeed/ops/transformer/inference/op_binding/softmax.py b/deepspeed/ops/transformer/inference/op_binding/softmax.py index bc309d94df14..2e08541596fa 100644 --- a/deepspeed/ops/transformer/inference/op_binding/softmax.py +++ b/deepspeed/ops/transformer/inference/op_binding/softmax.py @@ -3,11 +3,11 @@ # DeepSpeed Team -import os import torch import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp +from deepspeed.ops.transformer.inference.op_binding.workspace import InferenceContext class SoftmaxOp(BaseOp): @@ -25,24 +25,42 @@ def __init__(self, config: DeepSpeedInferenceConfig): except AttributeError: self.softmax_func = self.softmax_fallback - def softmax_fallback(self, attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size, - async_op, layer_scale, head_offset, mp_size): - if os.environ.get('DS_KI_FALLBACK') == 'True': - alibi = alibi[head_offset:head_offset + self.num_attention_heads_per_partition] - input_dtype = attn_scores.dtype - if (triangular): - tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool) - attn_scores = torch.masked_fill(attn_scores * layer_scale, tri, torch.finfo(input_dtype).min) - if alibi is not None: - attn_scores += alibi - if attn_mask is not None: - # expand atten_mask from two dim into 4 dim, insert two dims in the middle + @staticmethod + def softmax_fallback(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size, async_op, + layer_scale, head_offset, mp_size): + scores_len = len(attn_scores.size()) + heads = 1 + if scores_len > 1: + heads = attn_scores.size()[1] + num_attention_heads_per_partition = heads // mp_size + + if alibi is not None: + if len(alibi.shape) == 1: + alibi = None + else: + alibi = alibi[head_offset:head_offset + num_attention_heads_per_partition] + if attn_mask is not None and len(attn_mask.shape) == 1: + attn_mask = None + input_dtype = attn_scores.dtype + attn_scores *= layer_scale + + if alibi is not None: + attn_scores += alibi + if attn_mask is not None: + # expand atten_mask from two dim into 4 dim, insert two dims in the middle + if len(attn_mask.shape) == 2: attn_mask = attn_mask[:, None, None, :] - attn_scores += attn_mask - output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype) - return output - else: - raise NotImplementedError + attn_scores += attn_mask + if triangular: + if attn_scores.shape[2] == 1: # query using kv cache + token_idx = InferenceContext.Instance().current_tokens() + tri = torch.arange(attn_scores.shape[2], device=attn_scores.device).ge(token_idx) + else: + tri = 
~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool) + attn_scores = torch.masked_fill(attn_scores, tri, float('-inf')) + output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype) + + return output def forward(self, attn_scores: torch.Tensor, attn_mask: torch.Tensor, alibi: torch.Tensor, triangular: bool, recompute: bool, local_attention: bool, window_size: int, async_op: bool, layer_scale: float, diff --git a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py index 0dc4e08a3633..d745df678e93 100644 --- a/deepspeed/ops/transformer/inference/op_binding/softmax_context.py +++ b/deepspeed/ops/transformer/inference/op_binding/softmax_context.py @@ -7,6 +7,8 @@ from deepspeed import comm as dist from ..config import DeepSpeedInferenceConfig from .base import BaseOp +from .softmax import SoftmaxOp +from deepspeed.ops.transformer.inference.op_binding.workspace import InferenceContext class SoftmaxContextOp(BaseOp): @@ -23,13 +25,108 @@ def __init__(self, config: DeepSpeedInferenceConfig): except AttributeError: self.softmax_context_func = self.softmax_context_fallback + @staticmethod + def transform4d_0213(x, seq_length): + assert x.dim() == 3, F"Dim {x.dim()} is not supported" + batch_size, num_heads, seq_length_head_dim = x.shape + head_dim = seq_length_head_dim // seq_length + x = x.view(batch_size, num_heads, seq_length, head_dim) + x = x.permute(0, 2, 1, 3) + + return x + + @staticmethod + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep <= 1 or num_key_value_heads == 1: + return hidden_states + + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + @staticmethod + def bias_add_transform_0213(input, bias, num_heads, trans_count, perform_bias=False): + assert trans_count == 1 or trans_count == 3, F"Trans count {trans_count} is not supported" + assert input.dim() == 3, F"Dim {input.dim()} is not supported" + input_biased = torch.add(input, bias) if perform_bias else input + batch_size, seq_length, value_size = input_biased.shape + hid_dim = value_size // trans_count + head_dim = hid_dim // num_heads + + if trans_count == 1: + query_layer = input.view(batch_size, seq_length, num_heads, head_dim) + query_layer = query_layer.permute(0, 2, 1, 3) + key_layer = torch.zeros_like(query_layer) + value_layer = torch.zeros_like(query_layer) + return query_layer, key_layer, value_layer + + qkv_layers = input.view(batch_size, seq_length, 3, num_heads, head_dim) + query_layer, key_layer, value_layer = qkv_layers[..., 0, :, :], qkv_layers[..., 1, :, :], qkv_layers[..., + 2, :, :] + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + return query_layer, key_layer, value_layer + def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads, num_kv, norm_factor, triangular_masking, local_attention, window_size, no_masking, - layer_id, num_layers, alibi, rope_theta): - raise NotImplementedError + layer_id, num_layers, alibi, rope_theta, is_prompt, token_idx, position_ids): + bat_0213_query, bat_0213_key, bat_0213_value = self.bias_add_transform_0213( + query_key_value, None, heads, 3, False) + + if rotary_dim > 0 and 
rotate_half: + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + rotary = InferenceContext.Instance().get_rotary(rotary_dim, rope_theta, bat_0213_value.device) + cos, sin = rotary(bat_0213_value, InferenceContext.Instance().get_max_tokens_num()) + bat_0213_query, bat_0213_key = apply_rotary_pos_emb(bat_0213_query, bat_0213_key, cos, sin, position_ids) + + bat_0213_key, bat_0213_value = InferenceContext.Instance().update_cache(layer_id, token_idx, is_prompt, + bat_0213_key, bat_0213_value) + + bat_0213_key = self.repeat_kv(bat_0213_key, num_kv) + bat_0213_value = self.repeat_kv(bat_0213_value, num_kv) + + bsz = query_key_value.shape[0] + head_dim = query_key_value.shape[2] // (heads * 3) + + bmm_output = torch.bmm(bat_0213_query.reshape(bsz * heads, bat_0213_query.shape[2], head_dim), + bat_0213_key.reshape(bsz * heads, bat_0213_key.shape[2], head_dim).transpose(1, 2)) + + layer_scale = 1.0 + if alibi is not None and len(alibi.shape) > 1: + layer_scale = max(1, layer_id).to(float) + + alpha = norm_factor * norm_factor / layer_scale + bmm_output *= alpha + bmm_output_reshape = bmm_output.reshape(bsz, heads, bmm_output.shape[1], bmm_output.shape[2]) + + recompute = is_prompt + if attn_mask is not None and len(attn_mask.shape) > 1 and attn_mask.shape[-1] < bmm_output_reshape.shape[3]: + attn_mask = torch.nn.functional.pad(attn_mask, (0, bmm_output_reshape.shape[3] - attn_mask.shape[-1]), + value=torch.finfo(attn_mask.dtype).min) + softmax_output = SoftmaxOp.softmax_fallback(bmm_output_reshape, attn_mask, alibi, triangular_masking, + recompute, local_attention, window_size, None, layer_scale, 0, 1) + + output = torch.bmm(softmax_output.reshape(bsz * heads, softmax_output.shape[2], softmax_output.shape[3]), + bat_0213_value.reshape(bsz * heads, bat_0213_value.shape[2], head_dim)) + + output = output.reshape(bsz, heads, output.shape[1], head_dim) + output = output.reshape(bsz, heads, output.shape[2] * head_dim) + input_seq_len = query_key_value.shape[1] + t4d_0123_output = self.transform4d_0213(output, input_seq_len) + t4d_0123_output = t4d_0123_output.reshape(bsz, t4d_0123_output.shape[1], heads * head_dim) + + if layer_id == num_layers - 1: + InferenceContext.Instance().advance_tokens() + + return t4d_0123_output, bat_0213_key, bat_0213_value def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, num_kv: int, - norm_factor: float, no_masking: bool, layer_id: int, num_layers: int, alibi: torch.Tensor): + norm_factor: float, no_masking: bool, layer_id: int, num_layers: int, alibi: torch.Tensor, + is_prompt: bool, token_idx: torch.Tensor, position_ids: torch.Tensor): if alibi is not None: batch_heads = query_key_value.shape[0] * heads @@ -42,6 +139,6 @@ def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: self.config.rotate_every_two, heads, num_kv, norm_factor, self.config.triangular_masking, self.config.local_attention, self.config.window_size, no_masking, layer_id, num_layers, alibi, - self.config.rope_theta) + self.config.rope_theta, is_prompt, token_idx, position_ids) return output diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_add.py b/deepspeed/ops/transformer/inference/op_binding/vector_add.py new file mode 100644 index 000000000000..015340a1084b --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/vector_add.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. 
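The softmax_context fallback above relies on repeat_kv to expand grouped key/value heads so they match the query head count before the batched matmuls. A quick shape check of that expansion (illustrative):

import torch

batch, kv_heads, seqlen, head_dim, n_rep = 2, 4, 5, 8, 3
kv = torch.randn(batch, kv_heads, seqlen, head_dim)
# Same expand/reshape as SoftmaxContextOp.repeat_kv: replicate each KV head n_rep times.
expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seqlen, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seqlen, head_dim)
assert expanded.shape == (batch, 12, seqlen, head_dim)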
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + + +class VectorAddOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + super(VectorAddOp, self).__init__(config) + try: + self.vector_add_func = self.inference_module._vector_add + except AttributeError: + self.vector_add_func = self.vector_add_fallback + + @classmethod + def vector_add_fallback(cls, a, b, gamma): + """Based on csrc/transformer/inference/csrc/pt_binding.cpp code of _vector_add""" + dtype = a.dtype + return (gamma * a.float() + b.float()).to(dtype) + + def forward(self, a, b, gamma): + return self.vector_add_func(a, b, gamma) diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py index 011be859634d..cabab8d8c4ab 100644 --- a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py +++ b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import os import torch from ..config import DeepSpeedInferenceConfig from .base import BaseOp @@ -25,7 +24,7 @@ def __init__(self, config: DeepSpeedInferenceConfig): else: self.vector_matmul_func = self.inference_module.vector_matmul_fp16 elif self.config.dtype == torch.int8: - self.vector_matmul_func = self.inference_module.vector_matmul_fp16 + self.vector_matmul_func = self.inference_module.vector_matmul_int8 elif self.config.dtype == torch.bfloat16: self.vector_matmul_func = self.inference_module.vector_matmul_bf16 else: @@ -34,10 +33,7 @@ def __init__(self, config: DeepSpeedInferenceConfig): self.vector_matmul_func = self.vector_matmul_fallback def vector_matmul_fallback(self, input, weight, async_op, q_scale, q_int8, transpose): - if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose: - return torch.matmul(input, weight) - else: - raise NotImplementedError + return torch.matmul(input, weight.t() if transpose else weight) def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False): q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1) diff --git a/deepspeed/ops/transformer/inference/op_binding/workspace.py b/deepspeed/ops/transformer/inference/op_binding/workspace.py new file mode 100644 index 000000000000..19de7d9576af --- /dev/null +++ b/deepspeed/ops/transformer/inference/op_binding/workspace.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from ..config import DeepSpeedInferenceConfig +from .base import BaseOp + +minus_inf = -10000.0 +key_idx = 0 +value_idx = 1 + + +class InferenceContext: + + __instance = None + + def __init__(self): + self.kv_cache = None + self.kv_cache_elem_dtype = None + self.num_tokens = 1 + self.kv_cache_num_layers = None + self.kv_cache_size = None + self.max_out_tokens = None + self.rotary = None + self.allocate_called = False + self.static_shapes = True + + @classmethod + def Instance(cls): + if InferenceContext.__instance is None: + InferenceContext.__instance = InferenceContext() + return InferenceContext.__instance + + def gen_workspace(self, num_layers, num_heads, batch_size, prompt_len, hidden_dim, mp_size, external_cache, + elem_dtype, rank, max_out_tokens, min_out_tokens): + self.allocate_called = True + self.kv_cache = None + if not external_cache: + self.kv_cache_num_layers = num_layers + self.max_out_tokens = max_out_tokens + head_size = hidden_dim // num_heads + self.kv_cache_size = torch.Size([batch_size, (num_heads // mp_size), max_out_tokens, head_size]) + self.kv_cache_elem_dtype = elem_dtype + self.num_tokens = 0 + self.static_shapes = True + return True + + def retake_workspace(self): + return True + + def _retake_workspace(self): + assert self.allocate_called, "retake workspace called before allocate workspace" + + import deepspeed.accelerator as accelerator + if self.kv_cache is None: + self.kv_cache = [] + for layer in range(self.kv_cache_num_layers): + self.kv_cache.append((torch.zeros(self.kv_cache_size, + dtype=self.kv_cache_elem_dtype, + device=accelerator.get_accelerator().device_name()), + torch.zeros(self.kv_cache_size, + dtype=self.kv_cache_elem_dtype, + device=accelerator.get_accelerator().device_name()))) + + return True + + def update_cache(self, layer_id, token_idx, is_prompt, bat_0213_key, bat_0213_value): + has_workspace = self._retake_workspace() + assert has_workspace, "Could not allocate workspace" + + # Update current token + if is_prompt: + self.static_shapes = True + if token_idx is None: + self.static_shapes = False + InferenceContext.Instance().reset_tokens(bat_0213_key.shape[2]) + else: + InferenceContext.Instance().reset_tokens(token_idx) + + if token_idx is None: + token_idx = InferenceContext.Instance().current_tokens() + + bsz = bat_0213_key.shape[0] + + # Update cache content + if is_prompt: + cache_max_seq = self.kv_cache_size[2] + cache_max_head_dim = self.kv_cache_size[3] + seq = bat_0213_key.shape[2] + + mask = torch.arange(cache_max_seq, device=bat_0213_key.device) + mask = mask.ge(token_idx) + mask = mask.unsqueeze(-1) + mask = mask.expand([cache_max_seq, cache_max_head_dim]) + + self.kv_cache[layer_id][key_idx][:bsz, :, :seq, :].copy_(bat_0213_key) + self.kv_cache[layer_id][key_idx][:bsz, :].masked_fill_(mask, 0) + self.kv_cache[layer_id][value_idx][:bsz, :, :seq, :].copy_(bat_0213_value) + self.kv_cache[layer_id][value_idx][:bsz, :].masked_fill_(mask, 0) + else: + if self.static_shapes: + assert type(token_idx) == torch.Tensor, "token_idx is expected to be torch.Tensor" + self.kv_cache[layer_id][key_idx][:bsz].index_copy_(2, token_idx - 1, bat_0213_key) + self.kv_cache[layer_id][value_idx][:bsz].index_copy_(2, token_idx - 1, bat_0213_value) + else: + assert type(token_idx) == int, "token_idx is expected to be int" + self.kv_cache[layer_id][key_idx][:bsz, :, token_idx - 1:token_idx, :] = bat_0213_key + self.kv_cache[layer_id][value_idx][:bsz, :, token_idx - 1:token_idx, 
:] = bat_0213_value + + bat_0213_key = self.kv_cache[layer_id][key_idx][:bsz] + bat_0213_value = self.kv_cache[layer_id][value_idx][:bsz] + + if not self.static_shapes: + bat_0213_key = bat_0213_key[:, :, :token_idx, :] + bat_0213_value = bat_0213_value[:, :, :token_idx, :] + + return bat_0213_key, bat_0213_value + + def release_workspace(self): + self.kv_cache = None + self.rotary = None + + def reset_tokens(self, initial_tokens=1): + self.num_tokens = initial_tokens + + def current_tokens(self): + return self.num_tokens + + def advance_tokens(self): + self.num_tokens = self.num_tokens + 1 + + def get_kv_cache(self): + return self.kv_cache + + def get_rotary(self, rotary_dim, rope_theta, device=None): + if self.rotary is None: + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + + self.rotary = LlamaRotaryEmbedding(rotary_dim, base=rope_theta, device=device) + + return self.rotary + + def get_max_tokens_num(self): + return self.max_out_tokens + + +class WorkspaceOp(BaseOp): + + def __init__(self, config: DeepSpeedInferenceConfig = None): + if config is None: + config = DeepSpeedInferenceConfig() + self.inference_context = InferenceContext.Instance() + self._is_allocated = False + try: + super(WorkspaceOp, self).__init__(config) + if config.dtype == torch.float32: + self.allocate_workspace_func = self.inference_module.allocate_workspace_fp32 + elif config.dtype == torch.bfloat16: + self.allocate_workspace_func = self.inference_module.allocate_workspace_bf16 + else: + self.allocate_workspace_func = self.inference_module.allocate_workspace_fp16 + self.release_workspace_func = self.inference_module.release_workspace + self.retake_workspace_func = self.inference_module.retake_workspace + self.reset_cache_func = self.inference_module.reset_cache + except (ValueError, AttributeError) as e: + print(f"Using fallback functions in workspace because of {e}") + if config.dtype == torch.float32: + self.allocate_workspace_func = self.allocate_workspace_fp32_fallback + elif config.dtype == torch.bfloat16: + self.allocate_workspace_func = self.allocate_workspace_bf16_fallback + else: + self.allocate_workspace_func = self.allocate_workspace_fp16_fallback + self.release_workspace_func = self.release_workspace_fallback + self.retake_workspace_func = self.retake_workspace_fallback + self.reset_cache_func = self.reset_cache_fallback + + def allocate_workspace(self, *args, **kwargs): + self._is_allocated = True + return self.allocate_workspace_func(*args, **kwargs) + + def release_workspace(self): + self._is_allocated = False + return self.release_workspace_func() + + def reset_cache(self): + return self.reset_cache_func() if self.reset_cache_func else None + + def retake_workspace(self): + return self.retake_workspace_func() if self.retake_workspace_func else None + + def allocate_workspace_fp32_fallback(self, hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size, + external_cache, rank, max_out_tokens, min_out_tokens): + return self.inference_context.gen_workspace(num_layers, num_heads, batch_size, prompt_length, hidden_dim, + mp_size, external_cache, torch.float, rank, max_out_tokens, + min_out_tokens) + + def allocate_workspace_bf16_fallback(self, hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size, + external_cache, rank, max_out_tokens, min_out_tokens): + return self.inference_context.gen_workspace(num_layers, num_heads, batch_size, prompt_length, hidden_dim, + mp_size, external_cache, torch.bfloat16, rank, max_out_tokens, + min_out_tokens) + + 
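# These *_fallback allocators mirror the allocate_workspace_fp32/bf16/fp16 signatures of the compiled + # inference module: they only record the KV-cache geometry in InferenceContext, and the cache tensors are + # allocated lazily by _retake_workspace() on the first update_cache() call. +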
def allocate_workspace_fp16_fallback(self, hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size, + external_cache, rank, max_out_tokens, min_out_tokens): + return self.inference_context.gen_workspace(num_layers, num_heads, batch_size, prompt_length, hidden_dim, + mp_size, external_cache, torch.half, rank, max_out_tokens, + min_out_tokens) + + def reset_cache_fallback(self): + return self.inference_context.reset_tokens() + + def release_workspace_fallback(self): + return self.inference_context.release_workspace() + + def retake_workspace_fallback(self): + return self.inference_context.retake_workspace() + + def is_allocated(self): + return self._is_allocated diff --git a/deepspeed/ops/transformer/inference/triton/attention.py b/deepspeed/ops/transformer/inference/triton/attention.py index c05370ec74e5..6845d91b06be 100644 --- a/deepspeed/ops/transformer/inference/triton/attention.py +++ b/deepspeed/ops/transformer/inference/triton/attention.py @@ -125,7 +125,7 @@ def _triton_autotune(min_seqlen, context_4d_matmul(output, qkv, head_size) Fp16Matmul._update_autotune_table() - def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi): + def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids): if isinstance(qkv_out, list): qkv_out = qkv_out[0] @@ -143,7 +143,10 @@ def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi): no_masking=no_masking, layer_id=self.config.layer_id, num_layers=TritonSelfAttention.num_layers, - alibi=alibi) + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) context_layer, key_layer, value_layer = attn_key_value return context_layer, key_layer, value_layer @@ -161,7 +164,8 @@ def forward( norm_w=None, norm_b=None, alibi=None, - use_triton_attention=True): + use_triton_attention=True, + **kwargs): if not self.config.pre_layer_norm: qkv_out = self.linear_func(input=input, @@ -192,10 +196,16 @@ def forward( triangular=self.triangular_masking) key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:] else: + is_prompt = kwargs.get("first_token", qkv_out[0].shape[1] > 1) + token_idx = kwargs.get("token_idx", None) + position_ids = kwargs.get("position_ids", None) context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out, input_mask=input_mask, layer_past=layer_past, - alibi=alibi) + alibi=alibi, + is_prompt=is_prompt, + token_idx=token_idx, + position_ids=position_ids) output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) inp_norm = qkv_out[-1] diff --git a/deepspeed/ops/transformer/inference/triton/ops.py b/deepspeed/ops/transformer/inference/triton/ops.py index dd87d08d4d2c..dbed45313780 100644 --- a/deepspeed/ops/transformer/inference/triton/ops.py +++ b/deepspeed/ops/transformer/inference/triton/ops.py @@ -3,12 +3,10 @@ # DeepSpeed Team -import deepspeed -from deepspeed.ops.op_builder import InferenceBuilder import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext +from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual - -inference_module = None +from deepspeed.utils.types import ActivationFuncType def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode): @@ -76,15 +74,12 @@ def mlp_gemm_func(input, if use_triton_ln: mlp_input = layer_norm_residual(input, input_bias, residual, gamma, 
beta, epsilon) else: - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - mlp_input = inference_module._layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon) + mlp_input = LayerNormOp.layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon) # activation - if deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.GELU: + if ActivationFuncType(mlp_act_func_type) == ActivationFuncType.GELU: activation = "gelu" - elif deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.ReLU: + elif ActivationFuncType(mlp_act_func_type) == ActivationFuncType.ReLU: activation = "relu" else: activation = "" @@ -121,10 +116,7 @@ def qkv_gemm_func( if use_triton_ln: qkv_input = layer_norm(input, gamma, beta, epsilon) else: - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - qkv_input = inference_module.layer_norm(input, gamma, beta, epsilon) + qkv_input = LayerNormOp()(input, gamma, beta, epsilon) qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True) diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py index a991c4304563..8a6311bb6e83 100644 --- a/deepspeed/runtime/hybrid_engine.py +++ b/deepspeed/runtime/hybrid_engine.py @@ -17,16 +17,14 @@ from deepspeed.accelerator import get_accelerator from torch import nn from deepspeed.utils import logger - -from deepspeed.ops.op_builder import InferenceBuilder - from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding +from ..ops.transformer.inference.op_binding.workspace import WorkspaceOp + try: import transformers OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding except: OPTLearnedPositionalEmbedding = None -inference_cuda_module = None class DeepSpeedHybridEngine(DeepSpeedEngine): @@ -61,12 +59,8 @@ def __init__(self, args, model, **kwargs): self._total_batch_size = None self._gather_latency = 0 - global inference_cuda_module - if inference_cuda_module is None: - builder = InferenceBuilder() - inference_cuda_module = builder.load() - self.is_lora_fused = False + self.workspace = WorkspaceOp() def convert_to_linear_transposed(self, model): @@ -160,13 +154,13 @@ def unfuse_lora_weight_non_pinned(self): def retake_inference_cache(self): if self._config.hybrid_engine.release_inference_cache: - retake_success = inference_cuda_module.retake_workspace() + retake_success = self.workspace.retake_workspace() if not retake_success: logger.warning("Unable to acquire workspace on first attempt, emptying cache and retrying.") gc.collect() get_accelerator().empty_cache() - retake_success = inference_cuda_module.retake_workspace() + retake_success = self.workspace.retake_workspace() if not retake_success: raise RuntimeError("Unable to retake inference workspace.") @@ -269,7 +263,7 @@ def generate(self, *inputs, **kwargs): self.is_lora_fused = False if self._config.hybrid_engine.release_inference_cache: - inference_cuda_module.release_workspace() + self.workspace.release_workspace() gc.collect() get_accelerator().empty_cache() diff --git a/op_builder/hpu/__init__.py b/op_builder/hpu/__init__.py index 6527ace087b5..5ad1b9a7f891 100644 --- a/op_builder/hpu/__init__.py +++ b/op_builder/hpu/__init__.py @@ -7,4 +7,5 @@ from .cpu_adam import CPUAdamBuilder from .fused_adam import 
FusedAdamBuilder +from .transformer_inference import InferenceBuilder from .no_impl import NotImplementedBuilder diff --git a/op_builder/hpu/transformer_inference.py b/op_builder/hpu/transformer_inference.py new file mode 100644 index 000000000000..e397c99200ec --- /dev/null +++ b/op_builder/hpu/transformer_inference.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +import importlib + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class InferenceBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE" + NAME = "transformer_inference" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=self.NAME) + + def absolute_name(self): + return f"deepspeed.ops.transformer.inference.{self.NAME}_op" + + def sources(self): + return [] + + def load(self, verbose=True): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + from deepspeed.git_version_info import installed_ops # noqa: F401 + if installed_ops.get(self.name, False): + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py index 843c9b889c2b..f25bbc1be692 100644 --- a/tests/unit/ops/transformer/inference/test_bias_add.py +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -8,12 +8,13 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_add import BiasAddOp from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None torch_minor_version = None @@ -22,15 +23,8 @@ def run_bias_add_reference(activations, bias): def run_bias_add_ds(activations, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_add_fp16(activations, bias) - elif activations.dtype == torch.bfloat16: - return inference_module.bias_add_bf16(activations, bias) - else: - return inference_module.bias_add_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasAddOp(config)(activations, bias) @pytest.mark.inference_ops diff --git a/tests/unit/ops/transformer/inference/test_bias_geglu.py b/tests/unit/ops/transformer/inference/test_bias_geglu.py index d5ab13964974..05de4fbb4cf8 100644 --- a/tests/unit/ops/transformer/inference/test_bias_geglu.py +++ b/tests/unit/ops/transformer/inference/test_bias_geglu.py @@ -8,13 +8,13 @@ import deepspeed from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.accelerator import get_accelerator +from deepspeed.ops.transformer.inference.op_binding.gated_activation import GatedActivationOp 
from deepspeed.utils.types import ActivationFuncType from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None torch_minor_version = None @@ -27,10 +27,7 @@ def run_bias_geglu_reference(activations, bias): def run_bias_geglu_ds(activation, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.gated_activation(activation, bias, ActivationFuncType.GATED_GELU) + return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_GELU) @pytest.mark.inference_ops @@ -56,17 +53,14 @@ def run_gated_silu_reference(activations, bias): def run_gated_silu_ds(activation, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.gated_activation(activation, bias, ActivationFuncType.GATED_SILU) + return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_SILU) @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 2]) @pytest.mark.parametrize("sequence", [1, 128, 255]) @pytest.mark.parametrize("channels", [512, 1232, 4096]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", get_dtypes()) def test_gated_silu(batch, sequence, channels, dtype): activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name()) bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index fd82da51380c..b69030e87ace 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -8,13 +8,14 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp from .inference_test_utils import allclose, get_dtypes from packaging import version as pkg_version if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None torch_minor_version = None @@ -25,15 +26,8 @@ def run_bias_gelu_reference(activations, bias): def run_bias_gelu_ds(activations, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_gelu_fp16(activations, bias) - elif activations.dtype == torch.bfloat16: - return inference_module.bias_gelu_bf16(activations, bias) - else: - return inference_module.bias_gelu_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasGeluOp(config)(activations, bias) @pytest.mark.inference_ops diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py index 881af78e92cf..57134665b241 100644 --- a/tests/unit/ops/transformer/inference/test_bias_relu.py +++ b/tests/unit/ops/transformer/inference/test_bias_relu.py @@ -8,12 +8,13 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import 
InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_relu import BiasReluOp from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None torch_minor_version = None @@ -23,15 +24,8 @@ def run_bias_relu_reference(activations, bias): def run_bias_relu_ds(activations, bias): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_relu_fp16(activations, bias) - elif activations.dtype == torch.bfloat16: - return inference_module.bias_relu_bf16(activations, bias) - else: - return inference_module.bias_relu_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasReluOp(config)(activations, bias) @pytest.mark.inference_ops diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py index 675860b00bdb..beb74d09ab30 100644 --- a/tests/unit/ops/transformer/inference/test_gelu.py +++ b/tests/unit/ops/transformer/inference/test_gelu.py @@ -7,11 +7,12 @@ import torch import deepspeed from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None torch_minor_version = None @@ -45,13 +46,8 @@ def run_gelu_ds(activations, use_triton_ops=False): device = deepspeed.accelerator.get_accelerator().device_name() channels = activations.shape[-1] bias = torch.zeros((channels), dtype=activations.dtype, device=device) - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - if activations.dtype == torch.float16: - return inference_module.bias_gelu_fp16(activations, bias) - else: - return inference_module.bias_gelu_fp32(activations, bias) + config = DeepSpeedInferenceConfig(dtype=activations.dtype) + return BiasGeluOp(config)(activations, bias) @pytest.mark.inference_ops diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index 9eac612aa29c..2912807e9f43 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -8,6 +8,7 @@ import pytest from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp from .inference_test_utils import allclose, get_dtypes, assert_almost_equal try: import triton # noqa: F401 # type: ignore @@ -21,8 +22,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - def ref_implementation(vals, gamma, beta, epsilon, channels, dtype): vals_f = vals.to(torch.float32) @@ -32,10 +31,7 @@ def ref_implementation(vals, gamma, beta, epsilon, channels, dtype): def ds_implementation(vals, gamma, beta, epsilon): - global inference_module - if inference_module is None: - inference_module = 
InferenceBuilder().load() - return inference_module.layer_norm(vals, gamma, beta, epsilon) + return LayerNormOp()(vals, gamma, beta, epsilon) def ds_triton_implementation(vals, gamma, beta, epsilon): @@ -83,10 +79,7 @@ def residual_ref_implementation(vals, bias, res, gamma, beta, epsilon, channels, def residual_ds_implementation(vals, bias, res, gamma, beta, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module._layer_norm_residual(vals, bias, res, gamma, beta, epsilon) + return LayerNormOp.layer_norm_residual(vals, bias, res, gamma, beta, epsilon) def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon): @@ -137,10 +130,7 @@ def residual_store_ref_implementation(vals, bias, res, gamma, beta, epsilon, cha def residual_store_ds_implementation(vals, bias, res, gamma, beta, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.layer_norm_residual_store_pre_ln_res(vals, bias, res, gamma, beta, epsilon) + return LayerNormOp.layer_norm_residual_store_pre_ln_res(vals, bias, res, gamma, beta, epsilon) @pytest.mark.inference_ops diff --git a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py index e1c8127a83ac..dcf9f16baaf1 100644 --- a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py +++ b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py @@ -8,24 +8,20 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer.inference.op_binding.moe_res_matmul import MoEResMatmulOp from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - def run_moe_res_matmul_reference(residual, coef1, coef2, output): return residual * coef1 + output * coef2 def run_moe_res_matmul_ds(residual, coef, output): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() coef_t = coef.transpose(-1, -2).contiguous() - return inference_module.moe_res_matmul(residual, coef_t, output) + return MoEResMatmulOp()(residual, coef_t, output) @pytest.mark.inference_ops diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py index 91830e25fc81..807da4904341 100644 --- a/tests/unit/ops/transformer/inference/test_residual_add.py +++ b/tests/unit/ops/transformer/inference/test_residual_add.py @@ -8,6 +8,8 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.ops.transformer import DeepSpeedInferenceConfig +from deepspeed.ops.transformer.inference.op_binding import ResidualAddOp from .inference_test_utils import get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: @@ -36,11 +38,6 @@ def allclose(x, y): return torch.allclose(x, y, rtol=rtol, atol=atol) -@pytest.fixture(scope="module") -def inference_module(): - return InferenceBuilder().load() - - def res_add_bias_ref(hidden_state, residual, attn_output, attn_bias, final_bias, mp_size=1, pre_attn_norm=True): if pre_attn_norm: hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size @@ -75,8 +72,8 @@ def 
run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f @pytest.mark.parametrize("mp_size", [1, 2]) @pytest.mark.parametrize("pre_attn_norm", [True, False]) @pytest.mark.parametrize("use_triton_ops", [True, False]) -def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, - pre_attn_norm, use_triton_ops): +def test_residual_add(batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm, + use_triton_ops): if not deepspeed.HAS_TRITON and use_triton_ops: pytest.skip("triton has to be installed for the test") ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) @@ -96,19 +93,9 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_ if use_triton_ops: from deepspeed.ops.transformer.inference.triton import residual_add_bias ds_out = residual_add_bias(*res_add_args) - if dtype == torch.float16: - ds_out = inference_module.residual_add_bias_fp16(*res_add_args) - elif dtype == torch.float32: - ds_out = inference_module.residual_add_bias_fp32(*res_add_args) - elif dtype == torch.bfloat16: - ds_out = inference_module.residual_add_bias_bf16(*res_add_args) else: - if dtype == torch.float16: - ds_out = inference_module.residual_add_bias_fp16(*res_add_args) - elif dtype == torch.float32: - ds_out = inference_module.residual_add_bias_fp32(*res_add_args) - else: - raise ValueError(f"Unsupported dtype: {dtype}") + config = DeepSpeedInferenceConfig(dtype=dtype) + ds_out = ResidualAddOp(config).residual_add_func(*res_add_args) if not allclose(ds_out, ref_out): print((ds_out - ref_out).abs().max()) diff --git a/tests/unit/ops/transformer/inference/test_rms_norm.py b/tests/unit/ops/transformer/inference/test_rms_norm.py index 508a40e12e8d..fde9c9510771 100644 --- a/tests/unit/ops/transformer/inference/test_rms_norm.py +++ b/tests/unit/ops/transformer/inference/test_rms_norm.py @@ -8,13 +8,13 @@ import pytest from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder # type: ignore +from deepspeed.ops.transformer.inference.op_binding.pre_rms_norm import PreRMSNormOp +from deepspeed.ops.transformer.inference.op_binding.rms_norm import RMSNormOp from .inference_test_utils import allclose, get_dtypes if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - def ref_implementation(vals, gamma, epsilon): variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) @@ -27,10 +27,7 @@ def ref_implementation(vals, gamma, epsilon): def ds_implementation(vals, gamma, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.rms_norm(vals, gamma, epsilon) + return RMSNormOp()(vals, gamma, epsilon) @pytest.mark.inference_ops @@ -51,10 +48,7 @@ def test_rms_norm(batch, seq_len, channels, dtype): def pre_ds_implementation(vals, residual, gamma, epsilon): - global inference_module - if inference_module is None: - inference_module = InferenceBuilder().load() - return inference_module.pre_rms_norm(vals, residual, gamma, epsilon) + return PreRMSNormOp()(vals, residual, gamma, epsilon) def pre_ref_implementation(vals, residual, gamma, epsilon): diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py index 7d0d6e14b651..4f6d69160aa7 100644 --- 
a/tests/unit/ops/transformer/inference/test_softmax.py +++ b/tests/unit/ops/transformer/inference/test_softmax.py @@ -11,7 +11,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None torch_minor_version = None @@ -43,6 +42,7 @@ def run_softmax_ds(input, use_triton_ops=False): def test_softmax(batch, sequence, channels, dtype, use_triton_ops): if not deepspeed.HAS_TRITON and use_triton_ops: pytest.skip("triton has to be installed for the test") + device = deepspeed.accelerator.get_accelerator().device_name() input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device) input_ref = input_ds.clone().detach()
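The converted tests above share one pattern: construct the op-binding object (passing a DeepSpeedInferenceConfig where the op is dtype-dependent) and call it directly, so the compiled kernel is used when the inference builder is available and the Python fallback otherwise. A minimal sketch of that pattern outside pytest, assuming only that an accelerator device is visible and using arbitrary shapes, would be:

    import torch
    from deepspeed.accelerator import get_accelerator
    from deepspeed.ops.transformer import DeepSpeedInferenceConfig
    from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp

    device = get_accelerator().device_name()
    config = DeepSpeedInferenceConfig(dtype=torch.float16)
    activations = torch.randn((1, 128, 512), dtype=torch.float16, device=device)
    bias = torch.zeros((512, ), dtype=torch.float16, device=device)
    out = BiasGeluOp(config)(activations, bias)  # compiled kernel if built, Python fallback otherwise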