Rearrange inference OPS and stop using builder.load #5490

Open
wants to merge 64 commits into base: master

Changes from all commits (64 commits)
a4cec16
Add missing inference ops
oelayan7 May 1, 2024
4955653
Add missing fallback implementation for inference ops
oelayan7 May 1, 2024
aed7204
Use ops directly, stop using builder.load to use ops
oelayan7 May 1, 2024
c6f48c9
Add inference builder for cpu
oelayan7 Jun 4, 2024
e25d7b5
Merge branch 'master' into rearrange_ops
jomayeri Jun 6, 2024
b6ca9d8
add missing args to softmax context
oelayan7 Jun 9, 2024
3f2cf1e
Merge branch 'master' into rearrange_ops
lekurile Jun 10, 2024
2f55610
Merge branch 'master' into rearrange_ops
loadams Jun 12, 2024
2e9db19
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Jun 19, 2024
21ace48
Revert "Add inference builder for cpu"
oelayan7 Jun 13, 2024
cd93298
remove = from fstrings op, and unimplemented workspace
oelayan7 Jun 19, 2024
1117286
DS transformer set correct instance
oelayan7 Jun 20, 2024
c8ddcba
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Jun 23, 2024
6fe92d0
Merge branch 'master' into rearrange_ops
loadams Jun 26, 2024
7200aa8
Merge branch 'master' into rearrange_ops
loadams Jun 26, 2024
e73e3b2
Merge branch 'master' into rearrange_ops
loadams Jun 26, 2024
994ad36
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Jun 29, 2024
e7c8b38
fix CI issues
oelayan7 Jun 29, 2024
2bd46e2
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Jul 4, 2024
e08e7a5
remove irrelevant assert
oelayan7 Jul 4, 2024
77ededb
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Jul 10, 2024
268de18
Merge branch 'master' into rearrange_ops
loadams Jul 11, 2024
ec98a78
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Jul 29, 2024
6cea08f
Fix cuda compilation issue
oelayan7 Jul 29, 2024
85a90d3
Fix cuda compilation issue
oelayan7 Jul 29, 2024
cf837b3
Merge branch 'master' into rearrange_ops
tjruwase Jul 30, 2024
d05a939
Merge branch 'master' into rearrange_ops
lekurile Aug 5, 2024
cde5a5f
Merge branch 'master' into rearrange_ops
tjruwase Aug 10, 2024
96e8d35
Merge branch 'master' into rearrange_ops
tjruwase Aug 17, 2024
5fd5c08
Merge branch 'master' into rearrange_ops
tjruwase Aug 20, 2024
1d506db
Merge branch 'master' into rearrange_ops
loadams Aug 20, 2024
58f60f6
Merge branch 'master' into rearrange_ops
tjruwase Aug 21, 2024
c1eb49a
Merge branch 'master' into rearrange_ops
tjruwase Aug 29, 2024
ddd0021
Merge branch 'master' into rearrange_ops
tjruwase Aug 31, 2024
6d112b1
add verbos flag to add prints in workflow DS mii
oelayan7 Sep 3, 2024
2d31440
check if BF16 available before including the headers
oelayan7 Sep 4, 2024
cde622a
Merge branch 'master' into rearrange_ops
loadams Sep 4, 2024
be060c9
Merge branch 'master' into rearrange_ops
loadams Sep 4, 2024
b15b126
Merge branch 'master' into rearrange_ops
loadams Sep 6, 2024
1e5aafc
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Sep 9, 2024
58bd561
Merge branch 'master' into rearrange_ops
loadams Sep 9, 2024
8cb1dc8
Merge branch 'master' into rearrange_ops
loadams Sep 10, 2024
3aab92c
Merge branch 'master' into rearrange_ops
loadams Sep 11, 2024
5bbe040
Merge branch 'master' into rearrange_ops
loadams Sep 12, 2024
04f8b50
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Sep 18, 2024
1d9aeaf
revert workspace changes
oelayan7 Sep 18, 2024
191edd7
add comment to trigger nv-mii job
oelayan7 Sep 18, 2024
ac5f1a6
trigger ci
oelayan7 Sep 19, 2024
39b9b8c
revert changes that triggered mii job
oelayan7 Sep 22, 2024
fc74edd
Revert "revert workspace changes"
oelayan7 Sep 23, 2024
fdcf2ea
Add is_allocated to workspace
oelayan7 Sep 23, 2024
a01ed49
add reset_cache to workspace
oelayan7 Sep 25, 2024
d6a8c80
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Sep 26, 2024
a681bb6
Merge branch 'master' into rearrange_ops
loadams Sep 26, 2024
1ea0c81
Merge branch 'master' into rearrange_ops
loadams Sep 26, 2024
dc6e04e
Merge branch 'master' into rearrange_ops
loadams Sep 27, 2024
dc13454
Merge branch 'master' into rearrange_ops
loadams Sep 27, 2024
026c62f
Add missing call to allocate_workspace
oelayan7 Sep 29, 2024
ad67ede
add retake workspace func
oelayan7 Oct 1, 2024
b1a118a
Merge branch 'microsoft:master' into rearrange_ops
oelayan7 Oct 6, 2024
7a58485
Merge branch 'master' into rearrange_ops
oelayan7 Oct 8, 2024
a56848a
Merge branch 'master' into rearrange_ops
loadams Oct 8, 2024
dee13b9
Merge branch 'master' into rearrange_ops
loadams Oct 8, 2024
89cdf11
Merge branch 'master' into rearrange_ops
loadams Oct 8, 2024
9 changes: 6 additions & 3 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -452,14 +452,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi,
float rope_theta)
float rope_theta,
bool is_prompt,
std::optional<at::Tensor> token_idx,
std::optional<at::Tensor> position_ids)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
unsigned hidden_dim = heads * k;

bool is_prompt = (seq_len > 1);
is_prompt = (seq_len > 1);

if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();
@@ -2031,7 +2034,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \
m.def("dequantize_" #_name, \
&ds_dequantize<_dtype>, \
"DeepSpeed dequantize with " #_name " (CUDA)")
"DeepSpeed dequantize with " #_name " (CUDA)");

DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
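The hunk above widens `ds_softmax_context` so prompt handling is no longer decided only inside the kernel binding: `is_prompt`, `token_idx`, and `position_ids` now arrive from Python as explicit optional arguments, even though the body still recomputes `is_prompt = (seq_len > 1)`, and the `DEF_OPS` macro gains its missing trailing semicolon. A minimal Python-side sketch of that calling convention follows; `SoftmaxContextParams` and the fallback derivation of `is_prompt` are illustrative assumptions, not the actual DeepSpeed binding code.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SoftmaxContextParams:
    """Hypothetical container mirroring the extended ds_softmax_context signature."""
    is_prompt: Optional[bool] = None              # new explicit flag
    token_idx: Optional[torch.Tensor] = None      # new optional tensor
    position_ids: Optional[torch.Tensor] = None   # new optional tensor


def softmax_context_args(query_key_value: torch.Tensor,
                         params: SoftmaxContextParams) -> dict:
    """Derive the extra kwargs the binding now accepts.

    If the caller does not say whether this is a prompt pass, fall back to the
    same rule the C++ body applies: any sequence longer than one token is a prompt.
    """
    seq_len = query_key_value.size(1)
    is_prompt = params.is_prompt if params.is_prompt is not None else seq_len > 1
    return {
        "is_prompt": is_prompt,
        "token_idx": params.token_idx,
        "position_ids": params.position_ids,
    }


if __name__ == "__main__":
    qkv = torch.zeros(1, 8, 96)  # (batch, seq_len, 3 * hidden)
    print(softmax_context_args(qkv, SoftmaxContextParams()))  # is_prompt=True for seq_len=8
```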
17 changes: 4 additions & 13 deletions deepspeed/inference/engine.py
@@ -53,12 +53,7 @@ def __init__(self, model, config):
DS_INFERENCE_ENABLED = True

super().__init__()

# Have to import here because inference_module is a global, but python
# globals only work at the module level and will not be updated unless
# we import it each time we init a new inference engine.
from ..model_implementations.transformers.ds_transformer import inference_module
if inference_module is not None:
if DeepSpeedTransformerInference.workspace is not None:
self.destroy()

self.module = model
@@ -191,15 +186,11 @@ def __init__(self, model, config):
self._is_compiled = False

def destroy(self):
# Have to import here because inference_module is a global, but python
# globals only work at the module level and will not be updated unless
# we import it each time we init a new inference engine.
from ..model_implementations.transformers.ds_transformer import inference_module
DeepSpeedTransformerInference.layer_id = 0
DeepSpeedSelfAttention.num_layers = 0
if inference_module is not None:
inference_module.release_workspace()
inference_module = None
if DeepSpeedTransformerInference.workspace.is_allocated():
DeepSpeedTransformerInference.workspace.release_workspace()
DeepSpeedTransformerInference.workspace = None

def profile_model_time(self, use_cuda_events=True):
if not self.model_profile_enabled and not self._config.enable_cuda_graph:
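The engine no longer re-imports a module-level `inference_module` global to find out whether a previous engine left state behind; the workspace is a class attribute on `DeepSpeedTransformerInference`, so `__init__` and `destroy()` can test and release it directly. A condensed sketch of that lifecycle, using stand-in `Workspace`, `TransformerLayer`, and `Engine` classes rather than the real DeepSpeed types:

```python
class Workspace:
    """Stand-in for the WorkspaceOp binding: tracks whether device memory is held."""

    def __init__(self):
        self._allocated = False

    def allocate_workspace(self) -> None:
        # The real op would reserve device buffers here.
        self._allocated = True

    def is_allocated(self) -> bool:
        return self._allocated

    def release_workspace(self) -> None:
        # The real op would free the buffers here.
        self._allocated = False


class TransformerLayer:
    # Shared across all layer instances, mirroring DeepSpeedTransformerInference.workspace.
    workspace = None

    def __init__(self):
        if TransformerLayer.workspace is None:
            TransformerLayer.workspace = Workspace()


class Engine:
    def __init__(self, model):
        # A previous engine leaves its workspace on the class, so a new engine
        # can detect and tear it down without importing any module-level global.
        if TransformerLayer.workspace is not None:
            self.destroy()
        self.module = model

    def destroy(self):
        if TransformerLayer.workspace.is_allocated():
            TransformerLayer.workspace.release_workspace()
        TransformerLayer.workspace = None


layer = TransformerLayer()                  # first layer creates the shared workspace
TransformerLayer.workspace.allocate_workspace()
engine = Engine(model=layer)                # a fresh engine releases the leftover workspace
assert TransformerLayer.workspace is None
```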
15 changes: 2 additions & 13 deletions deepspeed/model_implementations/transformers/ds_llama2.py
@@ -4,11 +4,8 @@
# DeepSpeed Team

import torch
from deepspeed import comm as dist
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

inference_module = None


class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed OPT Transformer Layer.
@@ -27,18 +24,10 @@ def forward(self, *args, **kwargs):

input = args[0]
input_mask = None
# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self._alloc_workspace:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._alloc_workspace = False

get_present = True

self.allocate_workspace(input.size())

# We set the prev key/value to None when there is a prompt
if input.shape[1] > 1:
self.layer_past = None
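With the allocation arguments folded into the base class, the Llama2 subclass only hands over the input shape; the guard that restricts allocation to layer 0 on the first forward lives in one place. A minimal sketch of that division of labour, with simplified stand-in classes rather than the DeepSpeed ones:

```python
import torch


class BaseInferenceLayer:
    def __init__(self, layer_id: int):
        self.layer_id = layer_id
        self._should_allocate_workspace = True

    def allocate_workspace(self, size: torch.Size) -> None:
        # Only the first layer, on its first forward pass, actually allocates.
        if self.layer_id == 0 and self._should_allocate_workspace:
            print(f"allocating workspace for batch={size[0]}, seq_len={size[1]}")
            self._should_allocate_workspace = False


class Llama2Layer(BaseInferenceLayer):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The subclass no longer repeats the allocation arguments; it just
        # forwards the input shape to the shared helper.
        self.allocate_workspace(hidden_states.size())
        return hidden_states


layer = Llama2Layer(layer_id=0)
layer.forward(torch.zeros(2, 16, 4096))  # allocates once
layer.forward(torch.zeros(2, 1, 4096))   # no-op on later calls
```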
54 changes: 23 additions & 31 deletions deepspeed/model_implementations/transformers/ds_transformer.py
@@ -6,19 +6,18 @@
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
from deepspeed.utils.logging import log_dist

from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed
if deepspeed.HAS_TRITON:
from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

inference_module = None


class DeepSpeedTransformerInference(nn.Module):
"""Initialize the DeepSpeed Transformer Layer.
@@ -37,6 +36,7 @@ class DeepSpeedTransformerInference(nn.Module):
for specific downstream tasks.
"""
layer_id = 0
workspace = None

def __init__(self,
config,
@@ -52,10 +52,6 @@ def __init__(self,
DeepSpeedTransformerInference.layer_id += 1

data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
global inference_module
if inference_module is None:
builder = InferenceBuilder()
inference_module = builder.load()

if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
@@ -88,22 +84,25 @@ def __init__(self,
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.layer_past = None
try:
if config.dtype == torch.float32:
self.allocate_workspace = inference_module.allocate_workspace_fp32
elif config.dtype == torch.bfloat16:
self.allocate_workspace = inference_module.allocate_workspace_bf16
else:
self.allocate_workspace = inference_module.allocate_workspace_fp32
self._alloc_workspace = True
except AttributeError:
self.allocate_workspace = None
self._alloc_workspace = False
self.layer_norm = LayerNormOp()
if DeepSpeedTransformerInference.workspace is None:
DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config)
self._should_allocate_workspace = True

def allocate_workspace(self, size):
# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self._should_allocate_workspace:
DeepSpeedTransformerInference.workspace.allocate_workspace(
self.config.hidden_size, self.config.heads, size[1], size[0], DeepSpeedTransformerInference.layer_id,
self.config.mp_size, self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._should_allocate_workspace = False

@classmethod
def reset_cache(cls):
if inference_module is not None:
inference_module.reset_cache()
if cls.workspace is not None:
cls.workspace.reset_cache()

def forward(
self,
@@ -136,15 +135,7 @@ def forward(

input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self._alloc_workspace:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.min_out_tokens)
self._alloc_workspace = False
self.allocate_workspace(input.size())

get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask
@@ -178,14 +169,15 @@ def forward(
output_attentions,
self.norm_w,
self.norm_b,
alibi)
alibi,
**kwargs)

presents = (key, value)
self.layer_past = presents if layer_past is None else None
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

if not self.config.pre_layer_norm:
output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
output = self.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

output = output.to(input_type)
if get_present:
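The transformer layer now talks to op-binding objects (`WorkspaceOp`, `LayerNormOp`) instead of loading `InferenceBuilder` once into a global, which is what allows environments without the compiled CUDA extension to fall through to a Python implementation (see the "Add missing fallback implementation for inference ops" commit). The sketch below shows the general shape of such a binding; `LayerNormLikeOp` and its loader are illustrative stand-ins, not the actual DeepSpeed classes.

```python
import torch
import torch.nn.functional as F


class LayerNormLikeOp:
    """Illustrative op binding: prefer a compiled kernel, otherwise fall back to PyTorch.

    The real LayerNormOp/WorkspaceOp bindings referenced in the diff live under
    deepspeed/ops/transformer/inference/op_binding/; this class only mirrors the
    shape of that pattern.
    """

    def __init__(self):
        self.kernel = self._try_load_kernel()

    def _try_load_kernel(self):
        # In DeepSpeed this would ask the accelerator's op builder for the compiled
        # extension; here we simply pretend no kernel is available.
        return None

    def __call__(self, x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
                 eps: float) -> torch.Tensor:
        if self.kernel is not None:
            return self.kernel.layer_norm(x, gamma, beta, eps)
        # Fallback path: plain PyTorch layer norm over the last dimension.
        return F.layer_norm(x, (x.size(-1), ), gamma, beta, eps)


layer_norm = LayerNormLikeOp()
x = torch.randn(2, 4, 8)
out = layer_norm(x, torch.ones(8), torch.zeros(8), 1e-5)
```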
1 change: 0 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -103,7 +103,6 @@ def __init__(self,
self.return_tuple = return_tuple
self.mlp_after_attn = mlp_after_attn
self.mlp_act_func_type = mlp_act_func_type
self.specialized_mode = False
self.training_mp_size = training_mp_size
self.bigscience_bloom = bigscience_bloom
self.max_out_tokens = max_out_tokens
52 changes: 24 additions & 28 deletions deepspeed/ops/transformer/inference/diffusers_attention.py
@@ -10,10 +10,11 @@
from packaging import version as pkg_version
from deepspeed.utils.logging import log_dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
from deepspeed.ops.transformer.inference.op_binding.softmax_context import SoftmaxContextOp
from deepspeed.ops.transformer.inference.op_binding import LinearOp
from deepspeed.ops.transformer.inference.op_binding.pad_transform import PadTransformOp

# Cuda modules will be imported if needed
inference_module = None
minus_inf = -10000.0
triton_flash_attn = None

@@ -36,7 +37,8 @@ class DeepSpeedDiffusersAttentionFunction(Function):
@staticmethod
def forward(ctx, input, context, input_mask, config, attn_qkvw, attn_qw, attn_kw, attn_vw, attn_qkvb,
num_attention_heads_per_partition, norm_factor, hidden_size_per_partition, attn_ow, attn_ob,
do_out_bias, score_context_func, linear_func, triton_flash_attn_kernel, rope_theta):
do_out_bias, score_context_func, linear_func, pad_transform_func, triton_flash_attn_kernel,
rope_theta):

def _transpose_for_context(x):
x = x.permute(0, 2, 1, 3)
@@ -77,7 +79,7 @@ def selfAttention_fp(input, context, input_mask):
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
query, key, value = inference_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn)
query, key, value = pad_transform_func(query, key, value, config.heads, do_flash_attn)
attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
context_layer = _transpose_for_context(torch.matmul(attention_scores, value))

@@ -117,10 +119,6 @@ def __init__(

data_type = self.config.dtype
data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
global inference_module
if inference_module is None:
builder = InferenceBuilder()
inference_module = builder.load()

if DeepSpeedDiffusersAttention.layer_id == 1:
log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0])
@@ -171,26 +169,24 @@ def __init__(
self.norm_factor *= math.sqrt(self.config.layer_id + 1)
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191

if self.config.dtype in [torch.float16, torch.int8]:
self.score_context_func = inference_module.softmax_context_fp16
self.linear_func = inference_module.linear_layer_fp16
self.allocate_workspace = inference_module.allocate_workspace_fp16
else:
self.score_context_func = inference_module.softmax_context_fp32
self.linear_func = inference_module.linear_layer_fp32
self.allocate_workspace = inference_module.allocate_workspace_fp32
self.workspace = WorkspaceOp(self.config)
self.score_context_func = SoftmaxContextOp(self.config)
self.linear_func = LinearOp(self.config)
self.pad_transform_func = PadTransformOp(self.config)

def forward(self, input, context=None, input_mask=None):
def allocate_workspace(self, size):
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0], DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False,
0, self.config.max_out_tokens, self.config.min_out_tokens)
output = DeepSpeedDiffusersAttentionFunction.apply(input, context, input_mask, self.config, self.attn_qkvw,
self.attn_qw, self.attn_kw, self.attn_vw, self.attn_qkvb,
self.num_attention_heads_per_partition, self.norm_factor,
self.hidden_size_per_partition, self.attn_ow, self.attn_ob,
self.do_out_bias, self.score_context_func, self.linear_func,
self.triton_flash_attn_kernel, self.config.rope_theta)
self.workspace.allocate_workspace(self.config.hidden_size, self.config.heads, size[1], size[0],
DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False, 0,
self.config.max_out_tokens, self.config.min_out_tokens)

def forward(self, input, context=None, input_mask=None):
self.allocate_workspace(input.size())
output = DeepSpeedDiffusersAttentionFunction.apply(
input, context, input_mask, self.config, self.attn_qkvw, self.attn_qw, self.attn_kw, self.attn_vw,
self.attn_qkvb, self.num_attention_heads_per_partition, self.norm_factor, self.hidden_size_per_partition,
self.attn_ow, self.attn_ob, self.do_out_bias, self.score_context_func, self.linear_func,
self.pad_transform_func, self.triton_flash_attn_kernel, self.config.rope_theta)

return output
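In the attention module the dtype-dependent choice of kernel (the fp16 versus fp32 variants of `softmax_context`, `linear_layer`, and `allocate_workspace`) moves out of the caller and into `SoftmaxContextOp`, `LinearOp`, `PadTransformOp`, and `WorkspaceOp`, each constructed from the layer config. A small sketch of that dispatch pattern; the class and kernel stand-ins here are illustrative only:

```python
import torch


class DtypeDispatchingOp:
    """Illustrative pattern: the op binding, not the caller, picks the dtype-specific path.

    Before this change the attention module chose between e.g. softmax_context_fp16
    and softmax_context_fp32 itself; the op-binding classes fold that choice away.
    The dispatch table below is a stand-in, not the real DeepSpeed module.
    """

    def __init__(self, dtype: torch.dtype):
        dispatch = {
            torch.float16: self._run_half,
            torch.int8: self._run_half,   # int8 configs reuse the half-precision path
            torch.float32: self._run_float,
        }
        self._run = dispatch.get(dtype, self._run_float)

    def _run_half(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float16)

    def _run_float(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float32)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self._run(x)


op = DtypeDispatchingOp(torch.float16)
print(op(torch.zeros(2, 3)).dtype)  # torch.float16
```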
deepspeed/ops/transformer/inference/diffusers_transformer_block.py
@@ -10,26 +10,9 @@
from .diffusers_attention import DeepSpeedDiffusersAttention
from .bias_add import nhwc_bias_add
from .diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.ops.op_builder import InferenceBuilder, SpatialInferenceBuilder
from deepspeed.utils.types import ActivationFuncType

# Ops will be loaded on demand
transformer_cuda_module = None
spatial_cuda_module = None


def load_transformer_module():
global transformer_cuda_module
if transformer_cuda_module is None:
transformer_cuda_module = InferenceBuilder().load()
return transformer_cuda_module


def load_spatial_module():
global spatial_cuda_module
if spatial_cuda_module is None:
spatial_cuda_module = SpatialInferenceBuilder().load()
return spatial_cuda_module
from .op_binding.gated_activation import GatedActivationOp
from .op_binding.layer_norm import LayerNormOp


class DeepSpeedDiffusersTransformerBlock(nn.Module):
@@ -76,8 +59,8 @@ def __init__(self, equivalent_module: nn.Module, config: Diffusers2DTransformerC
else:
self.attn_2_bias = nn.Paramaeter(torch.zeros_like(self.norm3_g), requires_grad=False)

self.transformer_cuda_module = load_transformer_module()
load_spatial_module()
self.gated_activation = GatedActivationOp()
self.layer_norm = LayerNormOp()

def forward(self, hidden_states, context=None, timestep=None, **kwargs):
# In v0.12.0 of diffuser, several new kwargs were added. Capturing
Expand All @@ -88,17 +71,17 @@ def forward(self, hidden_states, context=None, timestep=None, **kwargs):
if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] is not None:
context = kwargs["encoder_hidden_states"]

out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps)
out_norm_1 = self.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps)
out_attn_1 = self.attn_1(out_norm_1)

out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(
out_norm_2, out_attn_1 = self.layer_norm.layer_norm_residual_store_pre_ln_res(
out_attn_1, self.attn_1_bias, hidden_states, self.norm2_g, self.norm2_b, self.norm2_eps)
out_attn_2 = self.attn_2(out_norm_2, context=context)
out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(
out_norm_3, out_attn_2 = self.layer_norm.layer_norm_residual_store_pre_ln_res(
out_attn_2, self.attn_2_bias, out_attn_1, self.norm3_g, self.norm3_b, self.norm3_eps)

out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w)
out_geglu = self.transformer_cuda_module.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU)
out_geglu = self.gated_activation(out_ff1, self.ff1_b, ActivationFuncType.GATED_GELU)

out_ff2 = nn.functional.linear(out_geglu, self.ff2_w)
return nhwc_bias_add(out_ff2, self.ff2_b, other=out_attn_2)
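The diffusers transformer block likewise swaps the lazily loaded `transformer_cuda_module` and spatial module for `GatedActivationOp` and `LayerNormOp` bound at construction time, so `out_geglu` comes from calling the op instance. Below is a sketch of what a Python fallback for the `GATED_GELU` path could look like; the function is illustrative and not taken from the DeepSpeed kernels.

```python
import torch
import torch.nn.functional as F


def gated_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Fallback-style GATED_GELU: add the bias, split the projection in half,
    and gate one half with GELU of the other.

    This mirrors what a GatedActivationOp fallback could do when the compiled
    inference kernel is unavailable; it is a sketch, not DeepSpeed's kernel.
    """
    x = x + bias
    hidden, gate = x.chunk(2, dim=-1)
    return hidden * F.gelu(gate)


ff1_out = torch.randn(1, 16, 2 * 64)     # feed-forward output before gating
bias = torch.zeros(2 * 64)
print(gated_gelu(ff1_out, bias).shape)   # torch.Size([1, 16, 64])
```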