Optimized some scattered optimization points in the framework
isky-cd committed Apr 8, 2024
1 parent 04aca9e · commit 57a9574
Showing 9 changed files with 55 additions and 44 deletions.
colossalai/inference/core/engine.py (8 changes: 4 additions & 4 deletions)
@@ -125,7 +125,7 @@ def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
0, max_num_blocks
) # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid in flash decoding, by making the first seqlen the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
output_tensor = torch.empty(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
@@ -371,13 +371,13 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,

sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
output_tensor = torch.empty(
(input_ids.size(0), batch.num_heads * batch.head_dim),
dtype=batch.dtype,
device=batch.device,
)
else:
output_tensor = torch.zeros(
output_tensor = torch.empty(
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)

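Both hunks above replace torch.zeros with torch.empty for output buffers that the attention kernels overwrite completely, and the prefill branch now sizes its buffer from input_ids.size(0), a host-side integer, instead of sequence_lengths.sum().item(), which forces a device-to-host synchronization. A minimal sketch of the idea, with illustrative shapes and names rather than the engine's real ones:

import torch

def make_output_buffer(input_ids: torch.Tensor, num_heads: int, head_dim: int,
                       dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # input_ids.size(0) is already a Python int on the host, so no GPU sync is needed.
    num_tokens = input_ids.size(0)
    # torch.empty skips the zero-fill kernel; this is safe only because the
    # attention kernels overwrite every element before the buffer is read.
    return torch.empty((num_tokens, num_heads * head_dim), dtype=dtype, device=device)

if torch.cuda.is_available():
    ids = torch.randint(0, 32000, (256,), device="cuda")
    out = make_output_buffer(ids, num_heads=32, head_dim=128,
                             dtype=torch.float16, device=torch.device("cuda"))
    print(out.shape)  # torch.Size([256, 4096])
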
colossalai/inference/core/request_handler.py (12 changes: 7 additions & 5 deletions)
@@ -297,11 +297,13 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
Sample tokens for finished requests.
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
for type in ["top_k", "top_p", "min_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
top_p = generation_config.top_p
top_k = generation_config.top_k

if top_k:
logits = logit_processor("top_k", logits, top_k)
if top_p:
logits = logit_processor("top_p", logits, top_p)

# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
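
Rather than rebuilding generation_config.to_dict() and scanning candidate keys on every decoding step, search_tokens now reads top_k and top_p directly and applies each logit processor only when it is set. The stand-alone sketch below shows the usual top-k / top-p filtering such a processor performs; it is an illustration under that assumption, not ColossalAI's logit_processor implementation:

import torch

def apply_top_k_top_p(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0) -> torch.Tensor:
    if top_k:
        # Keep the k largest logits per row, mask the rest to -inf.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if top_p:
        # Nucleus filtering: drop tokens outside the smallest prefix whose
        # cumulative probability exceeds top_p (always keep the top token).
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return logits

filtered = apply_top_k_top_p(torch.randn(2, 32000), top_k=50, top_p=0.9)
probs = torch.softmax(filtered, dim=-1, dtype=torch.float)
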
colossalai/inference/kv_cache/kvcache_manager.py (40 changes: 26 additions & 14 deletions)
@@ -138,7 +138,7 @@ def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> T
"""Get the key and value pointers of physical caches (of specific layer) corresponding to logical cache blocks indicated by the block table."""
k_ptrs = []
v_ptrs = []
for block_id in block_table:
for block_id in block_table.tolist():
if block_id >= 0:
block: CacheBlock = self._cache_blocks[block_id]
k_ptrs.append(block.k_ptrs[layer_id])
@@ -223,46 +223,58 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
self._block_states_cum[:-num_blocks_required],
out=self._block_finder[num_blocks_required - 1 :],
)

end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1)

block_tables_list = block_tables.tolist()
blocks_required_list = blocks_required.tolist()

if end_indexes.numel() > 0:
# contiguous cache exists
end_idx = end_indexes[0].item() + 1 # open interval
start_idx = end_idx - num_blocks_required # closed interval
alloc_block_ids = torch.arange(start_idx, end_idx)
alloc_block_ids = torch.arange(start_idx, end_idx, device=block_tables.device)

for i in range(bsz):
curr_required = blocks_required[i]
block_tables[i, :curr_required] = torch.arange(
start_idx, start_idx + curr_required, device=block_tables.device
)
curr_required = blocks_required_list[i]
for j in range(curr_required):
block_tables_list[i][j] = start_idx + j
start_idx += curr_required
else:
# non-contiguous cache
available_block_ids = torch.nonzero(self._block_states > 0).view(-1)
alloc_block_ids = available_block_ids[:num_blocks_required]
alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)
alloc_block_ids_list = alloc_block_ids.tolist()
start_idx = 0
for i in range(bsz):
curr_required = blocks_required[i]
block_tables[i, :curr_required] = alloc_block_ids[start_idx, start_idx + curr_required]
curr_required = blocks_required_list[i]
for j in range(curr_required):
block_tables_list[i][j] = alloc_block_ids_list[start_idx + j]
start_idx += curr_required

block_tables.copy_(torch.tensor(block_tables_list))

# Update cache blocks
self._block_states[alloc_block_ids] = 0
self._available_blocks -= num_blocks_required
last_block_locs = torch.cumsum(blocks_required, dim=0) - 1
last_block_locs = last_block_locs.to(device=alloc_block_ids.device)

for i, block_id in enumerate(alloc_block_ids[last_block_locs]):
last_block_ids = alloc_block_ids[last_block_locs].tolist()
alloc_block_ids = alloc_block_ids.tolist()
context_lengths = context_lengths.tolist()

for i, block_id in enumerate(last_block_ids):
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._allocate_on_block(
block,
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size,
else context_lengths[i] % block.block_size,
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:
if block_id in last_block_ids:
continue
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
@@ -336,15 +348,15 @@ def allocate_tokens_from_block_tables(
dtype=block_tables.dtype, device=block_tables.device
)

for block_id in alloc_block_ids:
for block_id in alloc_block_ids.tolist():
block: CacheBlock = self._cache_blocks[block_id]
block.add_ref()
self._block_states[block_id] = 0
self._available_blocks -= 1
block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids
block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]

for block_id in block_global_ids:
for block_id in block_global_ids.tolist():
self._allocate_on_block(self._cache_blocks[block_id], 1)

return seqs_to_recycle
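
The common thread in these hunks: small bookkeeping tensors (block tables, per-sequence block counts, context lengths) are pulled to Python lists once via .tolist(), the allocation loops then work on plain Python integers, and the result is written back with a single copy_. Indexing a tensor element by element inside a Python loop materializes a 0-d tensor per access and, for CUDA tensors, can force a synchronization each time. A toy version of the write-back pattern (names are illustrative, not the manager's real fields):

import torch

def fill_block_tables(block_tables: torch.Tensor, blocks_required: torch.Tensor, start_idx: int) -> torch.Tensor:
    tables = block_tables.tolist()            # one transfer instead of one tensor index per element
    required = blocks_required.tolist()
    for i, curr_required in enumerate(required):
        for j in range(curr_required):
            tables[i][j] = start_idx + j      # plain Python ints from here on
        start_idx += curr_required
    block_tables.copy_(torch.tensor(tables))  # single write-back into the original tensor
    return block_tables

bt = torch.full((3, 4), -1, dtype=torch.int32)
print(fill_block_tables(bt, torch.tensor([2, 1, 3]), start_idx=10))
# tensor([[10, 11, -1, -1],
#         [12, -1, -1, -1],
#         [13, 14, 15, -1]], dtype=torch.int32)
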
colossalai/inference/modeling/models/nopadding_llama.py (13 changes: 9 additions & 4 deletions)
@@ -121,6 +121,9 @@ def llama_model_forward(
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

norm_output = torch.empty_like(hidden_states)
silu_and_mul_output = torch.empty(
hidden_states.size(0), self.config.intermediate_size, dtype=hidden_states.dtype, device=hidden_states.device
)
residual = None

for layer_id, decoder_layer in enumerate(self.layers):
@@ -137,6 +140,7 @@
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
norm_output=norm_output,
silu_and_mul_output=silu_and_mul_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
@@ -167,6 +171,7 @@ def llama_decoder_layer_forward(
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
silu_and_mul_output: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
@@ -216,7 +221,7 @@ def llama_decoder_layer_forward(

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, silu_and_mul_output)

return hidden_states, residual

@@ -481,12 +486,12 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

return mlp_layer

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, silu_and_mul_output: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
inference_ops.silu_and_mul(gate_up_proj_out, silu_and_mul_output)
return torch.mm(silu_and_mul_output, self.down_proj_weight)
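
llama_model_forward now allocates silu_and_mul_output once per forward pass and threads it through every decoder layer, so the fused silu-and-mul kernel writes into a reused buffer instead of allocating a fresh output on each MLP call. A rough sketch of the buffer flow in plain PyTorch, with layer internals stubbed out (the real code calls inference_ops.silu_and_mul):

import torch
import torch.nn.functional as F

def mlp(hidden: torch.Tensor, gate_up_weight: torch.Tensor,
        down_weight: torch.Tensor, silu_and_mul_output: torch.Tensor) -> torch.Tensor:
    gate_up = torch.bmm(hidden.expand(2, -1, -1), gate_up_weight)  # [2, tokens, intermediate]
    # The committed code calls the fused inference_ops.silu_and_mul here;
    # this eager line writes the same result into the shared buffer.
    torch.mul(F.silu(gate_up[0]), gate_up[1], out=silu_and_mul_output)
    return torch.mm(silu_and_mul_output, down_weight)

tokens, hidden_size, inter = 8, 32, 64
gate_up_w = torch.randn(2, hidden_size, inter)
down_w = torch.randn(inter, hidden_size)
hidden = torch.randn(tokens, hidden_size)

buf = torch.empty(tokens, inter)   # allocated once, like silu_and_mul_output in llama_model_forward
for _ in range(4):                 # stand-in for the decoder-layer loop
    hidden = mlp(hidden, gate_up_w, down_w, buf)
print(hidden.shape)  # torch.Size([8, 32])
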
examples/inference/benchmark_llama.py (1 change: 1 addition & 0 deletions)
@@ -117,6 +117,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
block_size=32,
use_cuda_kernel=True,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
elif args.mode == "vllm":
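
The benchmark now opts into the custom CUDA kernels through use_cuda_kernel=True. A minimal configuration sketch; the import path and the literal values stand in for the benchmark's argument parsing and are assumptions, not part of the diff:

from colossalai.inference.config import InferenceConfig  # import path assumed

inference_config = InferenceConfig(
    max_output_len=128,        # stands in for args.output_len
    prefill_ratio=1.2,
    block_size=32,
    use_cuda_kernel=True,      # route attention/activation through the custom CUDA ops
)
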
extensions/csrc/cuda/activation_kernel.cu (13 changes: 1 addition & 12 deletions)
@@ -34,18 +34,8 @@ __global__ void act_and_mul_kernel(

// Note(LiuYang):This func is designed for calculation mode like
// silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins)
void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs)
{
// Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api
// to manipulate ins_shape which is IntArrayRef
auto ins_shape = ins.sizes().vec();

ins_shape[0] = ins_shape[0]/2;
if (ins_shape[0] == 1) {
ins_shape.erase(ins_shape.begin());
}
auto outs = torch::zeros(ins_shape,ins.options());

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Note(Liuyang): numel of ins must be divisible by 2
@@ -71,5 +61,4 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
);)

AT_CUDA_CHECK(cudaGetLastError());
return outs;
}
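
The wrapper no longer derives an output shape, allocates it with torch::zeros, and returns it; callers now pass outs and the kernel writes into it, saving an allocation and a zero-fill per call. An eager-mode contrast of the two calling conventions, assuming dim 0 stacks the gate and up operands as the note above describes (silu(x[:half_1stdim]) * x[half_1stdim:]); this is for intuition, not the CUDA implementation:

import torch
import torch.nn.functional as F

def silu_and_mul_old(ins: torch.Tensor) -> torch.Tensor:
    # Old wrapper behaviour: derive the output shape (first dim halved, a
    # leading 1 squeezed away), allocate with zeros, compute, return.
    half = ins.size(0) // 2
    out = torch.zeros_like(ins[:half]).squeeze(0)
    torch.mul(F.silu(ins[:half]).squeeze(0), ins[half:].squeeze(0), out=out)
    return out

def silu_and_mul_new(ins: torch.Tensor, outs: torch.Tensor) -> None:
    # New wrapper behaviour: the caller owns `outs`; nothing is allocated here.
    torch.mul(F.silu(ins[0]), ins[1], out=outs)

ins = torch.randn(2, 8, 64)
buf = torch.empty(8, 64)                 # reusable across calls
silu_and_mul_new(ins, buf)
print(torch.allclose(buf, silu_and_mul_old(ins)))  # True
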
extensions/csrc/cuda/pybind/inference.cpp (5 changes: 2 additions & 3 deletions)
@@ -39,7 +39,7 @@ void rotary_embedding_and_cache_copy(
torch::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision);

torch::Tensor silu_and_mul(const torch::Tensor& ins);
void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs);

void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
@@ -80,6 +80,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
"In-place fused Add and RMS Normalization.");

m.def("get_cos_and_sin", &get_cos_and_sin,
"Get cos and sin from the cache.");
m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache.");
}
tests/test_infer/test_inference_engine.py (4 changes: 3 additions & 1 deletion)
@@ -40,7 +40,9 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50

if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=True, dtype="fp32"
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
tests/test_infer/test_ops/cuda/test_silu_and_mul.py (3 changes: 2 additions & 1 deletion)
@@ -20,7 +20,8 @@ def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]

origin_out = inference_ops.silu_and_mul(origin_input)
origin_out = torch.empty_like(ref_out)
inference_ops.silu_and_mul(origin_input, origin_out)

if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
