diff --git a/.gitignore b/.gitignore
index 73cb365..ad0cd0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -122,7 +122,7 @@ celerybeat.pid
 # Environments
 .env
 .venv
-env/
+env*
 env_sarathi/
 env_flashinfer/
 env_flashinfer_2/
@@ -192,6 +192,7 @@ benchmark_output_old
 benchmark_output_old_1
 offline_inference_output
 capacity_search_output*
+capacity_latency_trends*
 isoqps_output*
 profiling_output
 env_vidur
diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index ba1642c..3c96070 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -24,27 +24,20 @@
 output_dir = f"{BASE_OUTPUT_DIR}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
 
 llm_engine = LLMEngine.from_engine_args(
-    # model="internlm/internlm-20b",
-    # model="mistralai/Mistral-7B-Instruct-v0.2",
-    # model="Qwen/Qwen-72B",
-    # model="01-ai/Yi-34B",
     model="meta-llama/Llama-2-7b-hf",
-    # model="meta-llama/Llama-2-70b-chat-hf",
-    # model="tiiuae/falcon-40b",
-    # model="tiiuae/falcon-180B",
-    # model="codellama/CodeLlama-34b-Instruct-hf",
-    # scheduler config
-    max_num_seqs=128,
-    # scheduler_type="vllm",
-    # sarathi scheduler config
-    scheduler_type="sarathi",
-    chunk_size=100,
     # parallel config
-    tensor_parallel_size=8,
-    pipeline_parallel_size=1,
+    tensor_parallel_size=4,
+    pipeline_parallel_size=2,
     trust_remote_code=True,
     max_model_len=4096,
-    attention_backend="FLASHINFER_UNPAGED"
+    # scheduler config
+    scheduler_type="sarathi",
+    chunk_size=100,
+    max_num_seqs=4,
+    # metrics config
+    write_metrics=False,
+    output_dir=output_dir,
+    enable_chrome_trace=True,
 )
diff --git a/sarathi/benchmark/capacity_search/capacity_search.py b/sarathi/benchmark/capacity_search/capacity_search.py
index fad8956..eaca3ba 100644
--- a/sarathi/benchmark/capacity_search/capacity_search.py
+++ b/sarathi/benchmark/capacity_search/capacity_search.py
@@ -159,6 +159,7 @@ def search(self):
         left = 0
         right = self.job_config.start_qps * 2
         qps = 0
+        last_qps = 0
 
         max_qps_under_sla = None
         min_qps_over_sla = 2**32
@@ -178,6 +179,13 @@
             # round to 2 decimal places
             qps = round(qps, 2)
 
+            if qps == last_qps:
+                break
+
+            last_qps = qps
+
+            print(f"Searching between {left} and {right} - qps: {qps}", flush=True)
+
             is_under_sla, scheduling_delay, tbt, run_id = self.is_under_sla(
                 qps)
 
@@ -218,8 +226,8 @@
             return {}
 
         logger.info(
-            f"Max QPS under SLO for {self.job_config.get_human_readable_name()}: "
-            f"{max_qps_under_sla}, Scheduling delay: {scheduling_delay_at_max_qps}, TBT: {tbt_at_max_qps}",
+            f"Max QPS under SLO for {self.job_config.get_human_readable_name()} - "
+            f"QPS: {max_qps_under_sla}, Scheduling delay: {scheduling_delay_at_max_qps}, TBT: {tbt_at_max_qps}",
             flush=True,
         )
         best_run = wandb.Api().run(f"{self.args.wandb_project}/{best_run_id}")
diff --git a/sarathi/benchmark/config/default.yml b/sarathi/benchmark/config/default.yml
index 14753d0..47c3438 100644
--- a/sarathi/benchmark/config/default.yml
+++ b/sarathi/benchmark/config/default.yml
@@ -14,13 +14,8 @@ cluster:
 
 model:
   name: meta-llama/Meta-Llama-3-8B
-  # name: meta-llama/Llama-2-7b-hf
-  # name: 01-ai/Yi-34B-200K
-  # name: codellama/CodeLlama-34b-Instruct-hf
-  # name: Qwen/Qwen-72B
-  # name: tiiuae/falcon-180B
-  tensor_parallel_degree: 8
-  pipeline_parallel_degree: 1
+  tensor_parallel_degree: 4
+  pipeline_parallel_degree: 2
   max_model_len: 4096
   load_format: dummy
   attention_backend: flash_attention
diff --git a/sarathi/model_executor/attention/flash_attention_wrapper.py b/sarathi/model_executor/attention/flash_attention_wrapper.py
index 114d3db..2a5cfcf 100644
--- a/sarathi/model_executor/attention/flash_attention_wrapper.py
+++ b/sarathi/model_executor/attention/flash_attention_wrapper.py
@@ -3,11 +3,14 @@
 from vllm_flash_attn import flash_attn_with_kvcache
 from typing import List, Optional, Tuple
 
+from sarathi.logger import init_logger
 from sarathi.config import ModelConfig, ParallelConfig
 from sarathi.core.datatypes.sequence import SequenceMetadata
 from sarathi.metrics.constants import OperationMetrics
 from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper
 
+logger = init_logger(__name__)
+
 
 class FlashAttentionWrapper(BaseAttentionWrapper):
     _inst = None
@@ -211,17 +214,26 @@ def forward(
             -1, 1, self.num_kv_heads, self.head_dim)
 
         with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id):
-            decode_output = flash_attn_with_kvcache(
-                decode_query,
-                kv_cache[0],  # k_cache,
-                kv_cache[1],  # v_cache,
-                decode_key,
-                decode_value,
-                cache_seqlens=self.decode_cache_len,
-                block_table=self.decode_block_table,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
+            try:
+                decode_output = flash_attn_with_kvcache(
+                    decode_query,
+                    kv_cache[0],  # k_cache,
+                    kv_cache[1],  # v_cache,
+                    decode_key,
+                    decode_value,
+                    cache_seqlens=self.decode_cache_len,
+                    block_table=self.decode_block_table,
+                    softmax_scale=softmax_scale,
+                    causal=True,
+                )
+            except RuntimeError as e:
+                if "If key is supplied, it must have seqlen <= the seqlen of the KV cache" in str(e):
+                    logger.warning(
+                        "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation."
+                    )
+                    return output
+                else:
+                    raise e
 
         with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
             # flatten the seq_output and copy it to the output tensor
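
Note on the last hunk: the except branch logs a warning and returns the pre-allocated output tensor, skipping decode attention for that step. A minimal standalone sketch of the same fallback pattern follows, for reference only; run_kernel_with_fallback, TRANSIENT_ERROR_MSG, and the plain logging logger are illustrative names, not part of the Sarathi codebase.

import logging

logger = logging.getLogger(__name__)

# Substring of the RuntimeError raised by flash_attn_with_kvcache when the new
# key is longer than the KV cache (the same message matched in the diff above).
TRANSIENT_ERROR_MSG = (
    "If key is supplied, it must have seqlen <= the seqlen of the KV cache"
)


def run_kernel_with_fallback(kernel, fallback_output, *args, **kwargs):
    """Run `kernel`; on the known transient error, warn and return `fallback_output`."""
    try:
        return kernel(*args, **kwargs)
    except RuntimeError as e:
        if TRANSIENT_ERROR_MSG in str(e):
            logger.warning(
                "Transient flash attention error (key seqlen exceeds KV cache); "
                "skipping this attention computation."
            )
            return fallback_output
        raise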