Port of PR 1810: Fixes from OSDI cam ready runs #9

Merged · 8 commits · May 27, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -122,7 +122,7 @@ celerybeat.pid
# Environments
.env
.venv
env/
env*
env_sarathi/
env_flashinfer/
env_flashinfer_2/
@@ -192,6 +192,7 @@ benchmark_output_old
benchmark_output_old_1
offline_inference_output
capacity_search_output*
capacity_latency_trends*
isoqps_output*
profiling_output
env_vidur
27 changes: 10 additions & 17 deletions examples/offline_inference.py
@@ -24,27 +24,20 @@
output_dir = f"{BASE_OUTPUT_DIR}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

llm_engine = LLMEngine.from_engine_args(
# model="internlm/internlm-20b",
# model="mistralai/Mistral-7B-Instruct-v0.2",
# model="Qwen/Qwen-72B",
# model="01-ai/Yi-34B",
model="meta-llama/Llama-2-7b-hf",
# model="meta-llama/Llama-2-70b-chat-hf",
# model="tiiuae/falcon-40b",
# model="tiiuae/falcon-180B",
# model="codellama/CodeLlama-34b-Instruct-hf",
# scheduler config
max_num_seqs=128,
# scheduler_type="vllm",
# sarathi scheduler config
scheduler_type="sarathi",
chunk_size=100,
# parallel config
tensor_parallel_size=8,
pipeline_parallel_size=1,
tensor_parallel_size=4,
pipeline_parallel_size=2,
trust_remote_code=True,
max_model_len=4096,
attention_backend="FLASHINFER_UNPAGED",
# scheduler config
scheduler_type="sarathi",
chunk_size=100,
max_num_seqs=4,
# metrics config
write_metrics=False,
output_dir=output_dir,
enable_chrome_trace=True,
)


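The parallel-config change above keeps the total accelerator count unchanged: the old TP=8, PP=1 layout and the new TP=4, PP=2 layout both span eight workers, since the world size is the product of the two degrees. A quick arithmetic check (plain Python, not part of the PR):

old_world_size = 8 * 1  # previous config: tensor_parallel_size=8, pipeline_parallel_size=1
new_world_size = 4 * 2  # updated config: tensor_parallel_size=4, pipeline_parallel_size=2
assert old_world_size == new_world_size == 8  # same number of GPUs, redistributed across TP/PP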
12 changes: 10 additions & 2 deletions sarathi/benchmark/capacity_search/capacity_search.py
@@ -159,6 +159,7 @@ def search(self):
left = 0
right = self.job_config.start_qps * 2
qps = 0
last_qps = 0
max_qps_under_sla = None
min_qps_over_sla = 2**32

@@ -178,6 +179,13 @@
# round to 2 decimal places
qps = round(qps, 2)

if qps == last_qps:
break

last_qps = qps

print(f"Searching between {left} and {right} - qps: {qps}", flush=True)

is_under_sla, scheduling_delay, tbt, run_id = self.is_under_sla(
qps)

@@ -218,8 +226,8 @@ def search(self)
return {}

logger.info(
f"Max QPS under SLO for {self.job_config.get_human_readable_name()}: "
f"{max_qps_under_sla}, Scheduling delay: {scheduling_delay_at_max_qps}, TBT: {tbt_at_max_qps}",
f"Max QPS under SLO for {self.job_config.get_human_readable_name()} - "
f"QPS: {max_qps_under_sla}, Scheduling delay: {scheduling_delay_at_max_qps}, TBT: {tbt_at_max_qps}",
flush=True,
)
best_run = wandb.Api().run(f"{self.args.wandb_project}/{best_run_id}")
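The `last_qps` guard added above terminates the capacity binary search once rounding to two decimal places makes the midpoint repeat; without it, the loop can keep re-probing the same QPS value. A simplified sketch of the search loop (a hypothetical standalone function, not the actual CapacitySearch class, which also tracks min_qps_over_sla and the wandb runs):

def find_max_qps_under_sla(is_under_sla, start_qps):
    left, right = 0.0, start_qps * 2
    last_qps = 0.0
    max_qps_under_sla = None
    while True:
        # Midpoint of the current range, rounded to 2 decimal places.
        qps = round((left + right) / 2, 2)
        # Rounding can make the midpoint stop changing even though left < right;
        # break instead of re-running the same benchmark configuration forever.
        if qps == last_qps:
            break
        last_qps = qps
        print(f"Searching between {left} and {right} - qps: {qps}", flush=True)
        if is_under_sla(qps):
            max_qps_under_sla = qps
            left = qps   # SLA met: probe a higher load next
        else:
            right = qps  # SLA violated: probe a lower load next
    return max_qps_under_sla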
9 changes: 2 additions & 7 deletions sarathi/benchmark/config/default.yml
@@ -14,13 +14,8 @@ cluster:

model:
name: meta-llama/Meta-Llama-3-8B
# name: meta-llama/Llama-2-7b-hf
# name: 01-ai/Yi-34B-200K
# name: codellama/CodeLlama-34b-Instruct-hf
# name: Qwen/Qwen-72B
# name: tiiuae/falcon-180B
tensor_parallel_degree: 8
pipeline_parallel_degree: 1
tensor_parallel_degree: 4
pipeline_parallel_degree: 2
max_model_len: 4096
load_format: dummy
attention_backend: flash_attention
34 changes: 23 additions & 11 deletions sarathi/model_executor/attention/flash_attention_wrapper.py
@@ -3,11 +3,14 @@
from vllm_flash_attn import flash_attn_with_kvcache
from typing import List, Optional, Tuple

from sarathi.logger import init_logger
from sarathi.config import ModelConfig, ParallelConfig
from sarathi.core.datatypes.sequence import SequenceMetadata
from sarathi.metrics.constants import OperationMetrics
from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper

logger = init_logger(__name__)


class FlashAttentionWrapper(BaseAttentionWrapper):
_inst = None
@@ -211,17 +214,26 @@ def forward(
-1, 1, self.num_kv_heads, self.head_dim)

with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id):
decode_output = flash_attn_with_kvcache(
decode_query,
kv_cache[0], # k_cache,
kv_cache[1], # v_cache,
decode_key,
decode_value,
cache_seqlens=self.decode_cache_len,
block_table=self.decode_block_table,
softmax_scale=softmax_scale,
causal=True,
)
try:
decode_output = flash_attn_with_kvcache(
decode_query,
kv_cache[0], # k_cache,
kv_cache[1], # v_cache,
decode_key,
decode_value,
cache_seqlens=self.decode_cache_len,
block_table=self.decode_block_table,
softmax_scale=softmax_scale,
causal=True,
)
except RuntimeError as e:
if "If key is supplied, it must have seqlen <= the seqlen of the KV cache" in str(e):
logger.warning(
"Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation."
)
return output
else:
raise e

with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id):
# flatten the seq_output and copy it to the output tensor
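The change above wraps the decode-path call to flash_attn_with_kvcache so that one specific transient RuntimeError (key length exceeding the KV-cache length) is logged and the attention computation skipped, while any other RuntimeError is re-raised. A generic sketch of that pattern (a hypothetical helper, not part of sarathi; the message text is the one matched in the diff):

import logging

logger = logging.getLogger(__name__)

TRANSIENT_MESSAGE = "If key is supplied, it must have seqlen <= the seqlen of the KV cache"

def call_with_transient_fallback(kernel, fallback, *args, **kwargs):
    # Run the kernel; if it fails with the known transient error, warn and
    # return the fallback value so the caller can continue. Anything else
    # propagates unchanged.
    try:
        return kernel(*args, **kwargs)
    except RuntimeError as e:
        if TRANSIENT_MESSAGE in str(e):
            logger.warning("Transient error in %s: %s; skipping.",
                           getattr(kernel, "__name__", kernel), e)
            return fallback
        raise

Matching on the exception message ties the fallback to the exact wording used by vllm_flash_attn, so the string may need updating if the library changes its error text.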