Skip to content

Commit

Permalink
Rearrange inference OPS and stop using builder.load (#5490)
Browse files Browse the repository at this point in the history
This PR mainly handles all places where InferenceBuilder is used to
access any op or a specific implementation for an op.
Instead an op is defined, and its proper implementation is picked inside
and the usage will be transparent to the user.
What was done in the PR:
1) Added missing ops (added a py file with fallback mechanism)
2) Added missing fallback implementations for existing ops
3) removed all usages for builder.load and replaced them with ops
instead.
4) added workspace op and inferenceContext which contains all workspace
related functions and inferenceContext is the python fallback of
inferenceContext in CUDA
5) a small change to softmax_context signature to fit the fallback
signature.

---------

Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
5 people authored Oct 8, 2024
1 parent 00c4b98 commit 030cf4a
Show file tree
Hide file tree
Showing 44 changed files with 1,064 additions and 363 deletions.
9 changes: 6 additions & 3 deletions csrc/transformer/inference/csrc/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi,
float rope_theta)
float rope_theta,
bool is_prompt,
std::optional<at::Tensor> token_idx,
std::optional<at::Tensor> position_ids)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
unsigned hidden_dim = heads * k;

bool is_prompt = (seq_len > 1);
is_prompt = (seq_len > 1);

if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();
Expand Down Expand Up @@ -2031,7 +2034,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \
m.def("dequantize_" #_name, \
&ds_dequantize<_dtype>, \
"DeepSpeed dequantize with " #_name " (CUDA)")
"DeepSpeed dequantize with " #_name " (CUDA)");

DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
Expand Down
17 changes: 4 additions & 13 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ def __init__(self, model, config):
DS_INFERENCE_ENABLED = True

super().__init__()

# Have to import here because inference_module is a global, but python
# globals only work at the module level and will not be updated unless
# we import it each time we init a new inference engine.
from ..model_implementations.transformers.ds_transformer import inference_module
if inference_module is not None:
if DeepSpeedTransformerInference.workspace is not None:
self.destroy()

self.module = model
Expand Down Expand Up @@ -191,15 +186,11 @@ def __init__(self, model, config):
self._is_compiled = False

def destroy(self):
# Have to import here because inference_module is a global, but python
# globals only work at the module level and will not be updated unless
# we import it each time we init a new inference engine.
from ..model_implementations.transformers.ds_transformer import inference_module
DeepSpeedTransformerInference.layer_id = 0
DeepSpeedSelfAttention.num_layers = 0
if inference_module is not None:
inference_module.release_workspace()
inference_module = None
if DeepSpeedTransformerInference.workspace.is_allocated():
DeepSpeedTransformerInference.workspace.release_workspace()
DeepSpeedTransformerInference.workspace = None

def profile_model_time(self, use_cuda_events=True):
if not self.model_profile_enabled and not self._config.enable_cuda_graph:
Expand Down
15 changes: 2 additions & 13 deletions deepspeed/model_implementations/transformers/ds_llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
# DeepSpeed Team

import torch
from deepspeed import comm as dist
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

inference_module = None


class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed OPT Transformer Layer.
Expand All @@ -27,18 +24,10 @@ def forward(self, *args, **kwargs):

input = args[0]
input_mask = None
# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self._alloc_workspace:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._alloc_workspace = False

get_present = True

self.allocate_workspace(input.size())

# We set the prev key/value to None when there is a prompt
if input.shape[1] > 1:
self.layer_past = None
Expand Down
54 changes: 23 additions & 31 deletions deepspeed/model_implementations/transformers/ds_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
from deepspeed.utils.logging import log_dist

from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed
if deepspeed.HAS_TRITON:
from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

inference_module = None


class DeepSpeedTransformerInference(nn.Module):
"""Initialize the DeepSpeed Transformer Layer.
Expand All @@ -37,6 +36,7 @@ class DeepSpeedTransformerInference(nn.Module):
for specific downstream tasks.
"""
layer_id = 0
workspace = None

def __init__(self,
config,
Expand All @@ -52,10 +52,6 @@ def __init__(self,
DeepSpeedTransformerInference.layer_id += 1

data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
global inference_module
if inference_module is None:
builder = InferenceBuilder()
inference_module = builder.load()

if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
Expand Down Expand Up @@ -88,22 +84,25 @@ def __init__(self,
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.layer_past = None
try:
if config.dtype == torch.float32:
self.allocate_workspace = inference_module.allocate_workspace_fp32
elif config.dtype == torch.bfloat16:
self.allocate_workspace = inference_module.allocate_workspace_bf16
else:
self.allocate_workspace = inference_module.allocate_workspace_fp32
self._alloc_workspace = True
except AttributeError:
self.allocate_workspace = None
self._alloc_workspace = False
self.layer_norm = LayerNormOp()
if DeepSpeedTransformerInference.workspace is None:
DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config)
self._should_allocate_workspace = True

def allocate_workspace(self, size):
# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self._should_allocate_workspace:
DeepSpeedTransformerInference.workspace.allocate_workspace(
self.config.hidden_size, self.config.heads, size[1], size[0], DeepSpeedTransformerInference.layer_id,
self.config.mp_size, self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._should_allocate_workspace = False

@classmethod
def reset_cache(cls):
if inference_module is not None:
inference_module.reset_cache()
if cls.workspace is not None:
cls.workspace.reset_cache()

def forward(
self,
Expand Down Expand Up @@ -136,15 +135,7 @@ def forward(

input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self._alloc_workspace:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._alloc_workspace = False
self.allocate_workspace(input.size())

get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask
Expand Down Expand Up @@ -178,14 +169,15 @@ def forward(
output_attentions,
self.norm_w,
self.norm_b,
alibi)
alibi,
**kwargs)

presents = (key, value)
self.layer_past = presents if layer_past is None else None
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

if not self.config.pre_layer_norm:
output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
output = self.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

output = output.to(input_type)
if get_present:
Expand Down
1 change: 0 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def __init__(self,
self.return_tuple = return_tuple
self.mlp_after_attn = mlp_after_attn
self.mlp_act_func_type = mlp_act_func_type
self.specialized_mode = False
self.training_mp_size = training_mp_size
self.bigscience_bloom = bigscience_bloom
self.max_out_tokens = max_out_tokens
Expand Down
52 changes: 24 additions & 28 deletions deepspeed/ops/transformer/inference/diffusers_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from packaging import version as pkg_version
from deepspeed.utils.logging import log_dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.ops.transformer.inference.op_binding.softmax_context import SoftmaxContextOp
from deepspeed.ops.transformer.inference.op_binding import LinearOp
from deepspeed.ops.transformer.inference.op_binding.pad_transform import PadTransformOp

# Cuda modules will be imported if needed
inference_module = None
minus_inf = -10000.0
triton_flash_attn = None

Expand All @@ -36,7 +37,8 @@ class DeepSpeedDiffusersAttentionFunction(Function):
@staticmethod
def forward(ctx, input, context, input_mask, config, attn_qkvw, attn_qw, attn_kw, attn_vw, attn_qkvb,
num_attention_heads_per_partition, norm_factor, hidden_size_per_partition, attn_ow, attn_ob,
do_out_bias, score_context_func, linear_func, triton_flash_attn_kernel, rope_theta):
do_out_bias, score_context_func, linear_func, pad_transform_func, triton_flash_attn_kernel,
rope_theta):

def _transpose_for_context(x):
x = x.permute(0, 2, 1, 3)
Expand Down Expand Up @@ -77,7 +79,7 @@ def selfAttention_fp(input, context, input_mask):
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
query, key, value = inference_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn)
query, key, value = pad_transform_func(query, key, value, config.heads, do_flash_attn)
attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
context_layer = _transpose_for_context(torch.matmul(attention_scores, value))

Expand Down Expand Up @@ -117,10 +119,6 @@ def __init__(

data_type = self.config.dtype
data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
global inference_module
if inference_module is None:
builder = InferenceBuilder()
inference_module = builder.load()

if DeepSpeedDiffusersAttention.layer_id == 1:
log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0])
Expand Down Expand Up @@ -171,26 +169,24 @@ def __init__(
self.norm_factor *= math.sqrt(self.config.layer_id + 1)
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

if self.config.dtype in [torch.float16, torch.int8]:
self.score_context_func = inference_module.softmax_context_fp16
self.linear_func = inference_module.linear_layer_fp16
self.allocate_workspace = inference_module.allocate_workspace_fp16
else:
self.score_context_func = inference_module.softmax_context_fp32
self.linear_func = inference_module.linear_layer_fp32
self.allocate_workspace = inference_module.allocate_workspace_fp32
self.workspace = WorkspaceOp(self.config)
self.score_context_func = SoftmaxContextOp(self.config)
self.linear_func = LinearOp(self.config)
self.pad_transform_func = PadTransformOp(self.config)

def forward(self, input, context=None, input_mask=None):
def allocate_workspace(self, size):
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False,
0, self.config.max_out_tokens, self.config.min_out_tokens)
output = DeepSpeedDiffusersAttentionFunction.apply(input, context, input_mask, self.config, self.attn_qkvw,
self.attn_qw, self.attn_kw, self.attn_vw, self.attn_qkvb,
self.num_attention_heads_per_partition, self.norm_factor,
self.hidden_size_per_partition, self.attn_ow, self.attn_ob,
self.do_out_bias, self.score_context_func, self.linear_func,
self.triton_flash_attn_kernel, self.config.rope_theta)
self.workspace.allocate_workspace(self.config.hidden_size, self.config.heads, size[1], size[0],
DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, 0,
self.config.max_out_tokens, self.config.min_out_tokens)

def forward(self, input, context=None, input_mask=None):
self.allocate_workspace(input.size())
output = DeepSpeedDiffusersAttentionFunction.apply(
input, context, input_mask, self.config, self.attn_qkvw, self.attn_qw, self.attn_kw, self.attn_vw,
self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, self.hidden_size_per_partition,
self.attn_ow, self.attn_ob, self.do_out_bias, self.score_context_func, self.linear_func,
self.pad_transform_func, self.triton_flash_attn_kernel, self.config.rope_theta)

return output
33 changes: 8 additions & 25 deletions deepspeed/ops/transformer/inference/diffusers_transformer_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,9 @@
from .diffusers_attention import DeepSpeedDiffusersAttention
from .bias_add import nhwc_bias_add
from .diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.ops.op_builder import InferenceBuilder, SpatialInferenceBuilder
from deepspeed.utils.types import ActivationFuncType

# Ops will be loaded on demand
transformer_cuda_module = None
spatial_cuda_module = None


def load_transformer_module():
global transformer_cuda_module
if transformer_cuda_module is None:
transformer_cuda_module = InferenceBuilder().load()
return transformer_cuda_module


def load_spatial_module():
global spatial_cuda_module
if spatial_cuda_module is None:
spatial_cuda_module = SpatialInferenceBuilder().load()
return spatial_cuda_module
from .op_binding.gated_activation import GatedActivationOp
from .op_binding.layer_norm import LayerNormOp


class DeepSpeedDiffusersTransformerBlock(nn.Module):
Expand Down Expand Up @@ -76,8 +59,8 @@ def __init__(self, equivalent_module: nn.Module, config: Diffusers2DTransformerC
else:
self.attn_2_bias = nn.Paramaeter(torch.zeros_like(self.norm3_g), requires_grad=False)

self.transformer_cuda_module = load_transformer_module()
load_spatial_module()
self.gated_activation = GatedActivationOp()
self.layer_norm = LayerNormOp()

def forward(self, hidden_states, context=None, timestep=None, **kwargs):
# In v0.12.0 of diffuser, several new kwargs were added. Capturing
Expand All @@ -88,17 +71,17 @@ def forward(self, hidden_states, context=None, timestep=None, **kwargs):
if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] is not None:
context = kwargs["encoder_hidden_states"]

out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps)
out_norm_1 = self.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps)
out_attn_1 = self.attn_1(out_norm_1)

out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(
out_norm_2, out_attn_1 = self.layer_norm.layer_norm_residual_store_pre_ln_res(
out_attn_1, self.attn_1_bias, hidden_states, self.norm2_g, self.norm2_b, self.norm2_eps)
out_attn_2 = self.attn_2(out_norm_2, context=context)
out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(
out_norm_3, out_attn_2 = self.layer_norm.layer_norm_residual_store_pre_ln_res(
out_attn_2, self.attn_2_bias, out_attn_1, self.norm3_g, self.norm3_b, self.norm3_eps)

out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w)
out_geglu = self.transformer_cuda_module.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU)
out_geglu = self.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU)

out_ff2 = nn.functional.linear(out_geglu, self.ff2_w)
return nhwc_bias_add(out_ff2, self.ff2_b, other=out_attn_2)
Loading

0 comments on commit 030cf4a

Please sign in to comment.