diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index a2388121bbf4..c4943ce08479 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -125,7 +125,7 @@ def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
                 0, max_num_blocks
             )  # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
             block_tables = torch.from_numpy(self.graph_block_tables).cuda()
-            output_tensor = torch.zeros(
+            output_tensor = torch.empty(
                 (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
             )
             fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
@@ -371,13 +371,13 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
         sequence_lengths = batch.get_sequence_lengths()

         if batch.is_prompts:
-            output_tensor = torch.zeros(
-                (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
+            output_tensor = torch.empty(
+                (input_ids.size(0), batch.num_heads * batch.head_dim),
                 dtype=batch.dtype,
                 device=batch.device,
             )
         else:
-            output_tensor = torch.zeros(
+            output_tensor = torch.empty(
                 (batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
             )

diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 9969c6786eab..995dac87916c 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -297,11 +297,13 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
         Sample tokens for finished requests.
         """
         # do logit processor
-        # NOTE: need to decide the granularity to process logits (sequence or batch)
-        config_dict = generation_config.to_dict()
-        for type in ["top_k", "top_p", "min_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+        top_p = generation_config.top_p
+        top_k = generation_config.top_k
+
+        if top_k:
+            logits = logit_processor("top_k", logits, top_k)
+        if top_p:
+            logits = logit_processor("top_p", logits, top_p)

         # calculate probs
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 7d435d59ceb8..548b468bda14 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -138,7 +138,7 @@ def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> T
         """Get the key and value pointers of physical caches (of specific layer) corresponding to logical cache blocks indicated by the block table."""
         k_ptrs = []
         v_ptrs = []
-        for block_id in block_table:
+        for block_id in block_table.tolist():
             if block_id >= 0:
                 block: CacheBlock = self._cache_blocks[block_id]
                 k_ptrs.append(block.k_ptrs[layer_id])
@@ -223,46 +223,58 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
             self._block_states_cum[:-num_blocks_required],
             out=self._block_finder[num_blocks_required - 1 :],
         )
+
         end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1)
+
+        block_tables_list = block_tables.tolist()
+        blocks_required_list = blocks_required.tolist()
+
         if end_indexes.numel() > 0:
             # contiguous cache exists
             end_idx = end_indexes[0].item() + 1  # open interval
             start_idx = end_idx - num_blocks_required  # closed interval
-            alloc_block_ids = torch.arange(start_idx, end_idx)
+            alloc_block_ids = torch.arange(start_idx, end_idx, device=block_tables.device)
+
             for i in range(bsz):
-                curr_required = blocks_required[i]
-                block_tables[i, :curr_required] = torch.arange(
-                    start_idx, start_idx + curr_required, device=block_tables.device
-                )
+                curr_required = blocks_required_list[i]
+                for j in range(curr_required):
+                    block_tables_list[i][j] = start_idx + j
                 start_idx += curr_required
         else:
             # non-contiguous cache
             available_block_ids = torch.nonzero(self._block_states > 0).view(-1)
             alloc_block_ids = available_block_ids[:num_blocks_required]
             alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)
+            alloc_block_ids_list = alloc_block_ids.tolist()
             start_idx = 0
             for i in range(bsz):
-                curr_required = blocks_required[i]
-                block_tables[i, :curr_required] = alloc_block_ids[start_idx, start_idx + curr_required]
+                curr_required = blocks_required_list[i]
+                for j in range(curr_required):
+                    block_tables_list[i][j] = alloc_block_ids_list[start_idx + j]
                 start_idx += curr_required
+
+        block_tables.copy_(torch.tensor(block_tables_list))
+
         # Update cache blocks
         self._block_states[alloc_block_ids] = 0
         self._available_blocks -= num_blocks_required
         last_block_locs = torch.cumsum(blocks_required, dim=0) - 1
-        last_block_locs = last_block_locs.to(device=alloc_block_ids.device)
-        for i, block_id in enumerate(alloc_block_ids[last_block_locs]):
+        last_block_ids = alloc_block_ids[last_block_locs].tolist()
+        alloc_block_ids = alloc_block_ids.tolist()
+        context_lengths = context_lengths.tolist()
+
+        for i, block_id in enumerate(last_block_ids):
             block: CacheBlock = self._cache_blocks[block_id]
             block.add_ref()
             self._allocate_on_block(
                 block,
                 block.block_size
                 if context_lengths[i] % block.block_size == 0
-                else context_lengths[i].item() % block.block_size,
+                else context_lengths[i] % block.block_size,
             )
         for block_id in alloc_block_ids:
-            if block_id in alloc_block_ids[last_block_locs]:
+            if block_id in last_block_ids:
                 continue
             block: CacheBlock = self._cache_blocks[block_id]
             block.add_ref()
@@ -336,7 +348,7 @@ def allocate_tokens_from_block_tables(
                 dtype=block_tables.dtype, device=block_tables.device
             )

-            for block_id in alloc_block_ids:
+            for block_id in alloc_block_ids.tolist():
                 block: CacheBlock = self._cache_blocks[block_id]
                 block.add_ref()
                 self._block_states[block_id] = 0
@@ -344,7 +356,7 @@ def allocate_tokens_from_block_tables(
             block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids

         block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
-        for block_id in block_global_ids:
+        for block_id in block_global_ids.tolist():
             self._allocate_on_block(self._cache_blocks[block_id], 1)

         return seqs_to_recycle
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index c5b61385f822..203ee4cc0b50 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -121,6 +121,9 @@ def llama_model_forward(
     sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

     norm_output = torch.empty_like(hidden_states)
+    silu_and_mul_output = torch.empty(
+        hidden_states.size(0), self.config.intermediate_size, dtype=hidden_states.dtype, device=hidden_states.device
+    )
     residual = None

     for layer_id, decoder_layer in enumerate(self.layers):
@@ -137,6 +140,7 @@ def llama_model_forward(
             kv_seq_len=kv_seq_len,
             output_tensor=output_tensor,
             norm_output=norm_output,
+            silu_and_mul_output=silu_and_mul_output,
             sm_scale=sm_scale,
             use_cuda_kernel=use_cuda_kernel,
             cu_seqlens=cu_seqlens,
@@ -167,6 +171,7 @@ def llama_decoder_layer_forward(
     kv_seq_len: int = 0,
     output_tensor: torch.Tensor = None,
     norm_output: torch.Tensor = None,
+    silu_and_mul_output: torch.Tensor = None,
     sm_scale: int = None,
     use_cuda_kernel: bool = True,
     cu_seqlens: torch.Tensor = None,
@@ -216,7 +221,7 @@ def llama_decoder_layer_forward(

     # Fully Connected
     hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
-    hidden_states = self.mlp(hidden_states)
+    hidden_states = self.mlp(hidden_states, silu_and_mul_output)

     return hidden_states, residual

@@ -481,12 +486,12 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

         return mlp_layer

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, silu_and_mul_output: torch.Tensor) -> torch.Tensor:
         """
         Args:
             hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
         """
         hidden_states = hidden_states.expand(2, -1, -1)
         gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
-        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
-        return torch.mm(act_out, self.down_proj_weight)
+        inference_ops.silu_and_mul(gate_up_proj_out, silu_and_mul_output)
+        return torch.mm(silu_and_mul_output, self.down_proj_weight)
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index 448a84c6fa0e..8128ce9f3f76 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -117,6 +117,7 @@ def benchmark_inference(args):
             max_output_len=args.output_len,
             prefill_ratio=1.2,
             block_size=32,
+            use_cuda_kernel=True,
         )
         engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
     elif args.mode == "vllm":
diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu
index 372b303875cb..83dcfc582794 100644
--- a/extensions/csrc/cuda/activation_kernel.cu
+++ b/extensions/csrc/cuda/activation_kernel.cu
@@ -34,18 +34,8 @@ __global__ void act_and_mul_kernel(

 // Note(LiuYang):This func is designed for calculation mode like
 // silu(x[:half_1stdim]) * (x[half_1stdim:])
-torch::Tensor silu_and_mul(const torch::Tensor& ins)
+void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs)
 {
-  // Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api
-  // to manipulate ins_shape which is IntArrayRef
-  auto ins_shape = ins.sizes().vec();
-
-  ins_shape[0] = ins_shape[0]/2;
-  if (ins_shape[0] == 1) {
-    ins_shape.erase(ins_shape.begin());
-  }
-  auto outs = torch::zeros(ins_shape,ins.options());
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   // Note(Liuyang): numel of ins must be divisible by 2
@@ -71,5 +61,4 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
   );)

   AT_CUDA_CHECK(cudaGetLastError());
-  return outs;
 }
diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp
index 45745e6a3e29..ac5b5fe24553 100644
--- a/extensions/csrc/cuda/pybind/inference.cpp
+++ b/extensions/csrc/cuda/pybind/inference.cpp
@@ -39,7 +39,7 @@ void rotary_embedding_and_cache_copy(
     torch::Tensor& block_tables,  // [batch_size, max_seq_len]
     bool high_precision);

-torch::Tensor silu_and_mul(const torch::Tensor& ins);
+void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs);

 void rms_layernorm(torch::Tensor& out,    // [..., hidden_size]
                    torch::Tensor& input,  // [..., hidden_size]
@@ -80,6 +80,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
         "In-place fused Add and RMS Normalization.");

-  m.def("get_cos_and_sin", &get_cos_and_sin,
-        "Get cos and sin from the cache.");
+  m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache.");
 }
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 25b2c2f4318a..c29a06e38f47 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -40,7 +40,9 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     top_k = 50

     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(
+            max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=True, dtype="fp32"
+        )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
diff --git a/tests/test_infer/test_ops/cuda/test_silu_and_mul.py b/tests/test_infer/test_ops/cuda/test_silu_and_mul.py
index ced2db7ca048..e606c63ad7eb 100644
--- a/tests/test_infer/test_ops/cuda/test_silu_and_mul.py
+++ b/tests/test_infer/test_ops/cuda/test_silu_and_mul.py
@@ -20,7 +20,8 @@ def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
     act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
     ref_out = act_out * ref_input[1]

-    origin_out = inference_ops.silu_and_mul(origin_input)
+    origin_out = torch.empty_like(ref_out)
+    inference_ops.silu_and_mul(origin_input, origin_out)

     if dtype == torch.float32:
         assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
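
A quick way to sanity-check the new out-parameter contract of silu_and_mul without building the CUDA extension is a pure-PyTorch reference. The sketch below only assumes what the kernel's own comment and the updated test already state: the input is split in half along dim 0, the result is silu(first_half) * second_half, and it is written into a caller-provided buffer. The helper name silu_and_mul_ref and the tensor shapes are illustrative, not part of the codebase.

import torch
import torch.nn.functional as F


def silu_and_mul_ref(ins: torch.Tensor, outs: torch.Tensor) -> None:
    # silu(ins[:half]) * ins[half:], written into the caller-provided `outs`.
    half = ins.size(0) // 2
    result = F.silu(ins[:half]) * ins[half:]
    # Reshape to the out buffer's layout (the fused op drops a leading dim of size 1).
    outs.copy_(result.reshape(outs.shape))


# Usage mirroring the updated test: pre-allocate the output, then call the op.
x = torch.randn(2, 64, 128, device="cuda", dtype=torch.float16)
out = torch.empty(64, 128, device="cuda", dtype=torch.float16)
silu_and_mul_ref(x, out)  # inference_ops.silu_and_mul(x, out) should agree within tolerance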
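
The repeated .tolist() changes in kvcache_manager.py all follow one pattern: move a small GPU tensor to host-side Python ints in a single transfer instead of iterating the tensor element by element. Iterating a CUDA tensor yields 0-dim tensors, and each comparison or indexing operation on them forces its own device-to-host synchronization. Below is a minimal sketch of the two styles, with hypothetical names (gather_ptrs_slow, gather_ptrs_fast, ptrs_by_block) standing in for the cache-block lookups.

import torch


def gather_ptrs_slow(block_table: torch.Tensor, ptrs_by_block: list) -> list:
    # Each iteration yields a 0-dim CUDA tensor; the `>= 0` check and the list
    # indexing both trigger a separate GPU-to-CPU sync per element.
    return [ptrs_by_block[block_id] for block_id in block_table if block_id >= 0]


def gather_ptrs_fast(block_table: torch.Tensor, ptrs_by_block: list) -> list:
    # `.tolist()` copies the whole table to the host once; the loop then runs
    # entirely on plain Python ints.
    return [ptrs_by_block[block_id] for block_id in block_table.tolist() if block_id >= 0]


block_table = torch.tensor([3, 0, 2, -1, -1], device="cuda")
ptrs_by_block = [0x1000, 0x2000, 0x3000, 0x4000]
assert gather_ptrs_slow(block_table, ptrs_by_block) == gather_ptrs_fast(block_table, ptrs_by_block)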
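
On the swaps from torch.zeros to torch.empty for the pre-allocated attention output buffers: torch.zeros launches an extra fill kernel on every call, while torch.empty only reserves uninitialized memory. The sketch below is illustrative and assumes, as the caller must guarantee, that the downstream kernel writes every element of the buffer before anything reads it; the sizes are made up.

import torch

num_tokens, num_heads, head_dim = 16, 32, 128  # illustrative sizes

# Reserves memory without clearing it; contents are garbage until the attention
# kernel overwrites every element, so nothing may read it before then.
output_tensor = torch.empty((num_tokens, num_heads * head_dim), dtype=torch.float16, device="cuda")

# The torch.zeros version allocates the same buffer but also launches a fill
# kernel each step; that cost is pure overhead when every element is overwritten.
zeroed_tensor = torch.zeros((num_tokens, num_heads * head_dim), dtype=torch.float16, device="cuda")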