Update model definition to support Flash-Decoding (#177)
* test stub

* wip

* wip

* wip

* compiled

* wip

* fix

* fix

* wip, decode with flash decoding works

* all work

* add paged_kv_cache_type option

* read kv_type from artifact

* black

* refactor attention backend

* minor clean up
masahi authored Jan 30, 2024
1 parent e1bd866 commit 253da78
Showing 3 changed files with 494 additions and 181 deletions.
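
For orientation, a condensed sketch of the selection logic the updated example script performs at runtime (the function and option names are taken from the diff below, as are the block sizes and block counts; this is a summary, not additional code from the commit):

    kv_type = get_paged_kv_cache_type(args.artifact_path)  # "vllm" or "flash-decoding", read from build_config.json

    if kv_type == "flash-decoding":
        allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
        block_size, num_blocks = 256, 30
    else:
        allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
        block_size, num_blocks = 16, 500

The chosen allocator name and block size are then threaded down into CacheManager and KVCache in place of the previous disco_session argument, which is what most of the changes in the first file are about.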
150 changes: 117 additions & 33 deletions examples/python/run_llama_batched_vllm.py
@@ -20,29 +20,29 @@


class KVCache:
def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session):
if disco_session:
init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
else:
init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")

def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, init_cache_func):
self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks)

self.block_tables = defaultdict(list)
self.slot_mappings = defaultdict(list)
self.block_size = block_size


class CacheManager:
block_size: int = 16

def __init__(
self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None
self,
num_blocks,
block_size,
num_layers,
num_heads,
head_size,
init_cache_func,
sliding_window=None,
):
self.block_size = block_size
self.num_blocks = num_blocks
self.free_blocks = list(range(num_blocks))
self.kv_cache = KVCache(
num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session
num_blocks, self.block_size, num_layers, num_heads, head_size, init_cache_func
)

if sliding_window:
@@ -172,6 +172,7 @@ def _prepare_inputs(
sliding_window,
dev,
is_prefill,
query_token_len=1,
):
block_tables = []
seq_lens = []
@@ -201,13 +202,16 @@ def _prepare_inputs(
start_idx += prompt_len

else:
input_ids.append(token_ids[-1])
pos = len(token_ids) - 1
positions.append(pos)
input_ids += token_ids[-query_token_len:]

for i in range(query_token_len):
positions.append(len(token_ids) - (query_token_len - i))

slot_mapping += all_slot_mappings[request_id][-query_token_len:]

block_table = all_block_tables[request_id]
max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
block_tables.append(block_table)
slot_mapping.append(all_slot_mappings[request_id][-1])

if sliding_window:
seq_lens.append(min(len(token_ids), sliding_window))
@@ -316,7 +320,15 @@ def _prepare_eval_queries(

class Model:
def __init__(
self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window
self,
artifact_path,
model_name,
quant,
vocab_size,
num_shards,
dev,
sliding_window,
block_size,
):
self.mod, self.params, self.disco_session = get_tvm_model(
artifact_path, model_name, quant, num_shards, dev
@@ -326,7 +338,7 @@ def __init__(
self.sliding_window = sliding_window

if sliding_window:
self.block_sliding_window = sliding_window // CacheManager.block_size
self.block_sliding_window = sliding_window // block_size
else:
self.block_sliding_window = None

@@ -409,6 +421,15 @@ def generate(
]


def get_paged_kv_cache_type(model_artifact_path):
config_file_path = os.path.join(model_artifact_path, "build_config.json")
assert os.path.exists(config_file_path)

with open(config_file_path, mode="rt", encoding="utf-8") as f:
build_cfg = json.load(f)
return build_cfg["paged_kv_cache_type"]


def parse_args():
# Example
# python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention
@@ -444,6 +465,18 @@ def run(args):
with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f:
config = LlamaConfig(**json.load(i_f))

kv_type = get_paged_kv_cache_type(args.artifact_path)
use_flash_decoding = kv_type == "flash-decoding"

if use_flash_decoding:
allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
block_size = 256
num_blocks = 30
else:
allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
block_size = 16
num_blocks = 500

model = Model(
artifact_path,
model_name,
@@ -452,20 +485,26 @@ def run(args):
args.num_shards,
dev,
config.sliding_window,
block_size,
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)

num_kv_heads = config.get_num_key_value_heads() // args.num_shards
head_size = config.hidden_size // config.num_attention_heads
num_blocks = 500

if model.disco_session:
init_cache_func = model.disco_session.get_global_func(allocate_func_name)
else:
init_cache_func = tvm.get_global_func(allocate_func_name)

cache_manager = CacheManager(
num_blocks,
block_size,
config.num_hidden_layers,
num_kv_heads,
head_size,
model.disco_session,
init_cache_func,
sliding_window=config.sliding_window,
)
cache = cache_manager.get()
@@ -516,8 +555,28 @@ def run(args):
for p, g in zip(prompts, generated):
print("Prompt = '{}', generated text = '{}'".format(p, g))

query_token_lens = [4, 3, 5, 2]
if model.disco_session:
return

def verify_logits(logits, query_token_lens):
assert logits.shape[0] == sum(query_token_lens)

logits_offset = 0

for request_id, query_token_len in zip(request_ids, query_token_lens):
for i in range(query_token_len - 1):
# requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
# Doing argmax over multi-timestep logits computed in parallel should yield the same
# tokens at the corresponding positions.
past_tokens = requests[request_id].token_ids[:-query_token_len]
assert (
np.argmax(logits[logits_offset + i])
== requests[request_id].token_ids[len(past_tokens) + i + 1]
)

logits_offset += query_token_len

query_token_lens = [4, 3, 5, 2]
eval_query_requests = []

for request_id, query_token_len in zip(request_ids, query_token_lens):
@@ -552,22 +611,47 @@ def run(args):
model.params,
)[0].numpy()

assert logits.shape[0] == sum(query_token_lens)
verify_logits(logits, query_token_lens)

logits_offset = 0
if not use_flash_decoding:
return

for request_id, query_token_len in zip(request_ids, query_token_lens):
for i in range(query_token_len - 1):
# requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
# Doing argmax over multi-timestep logits computed in parallel should yield the same
# tokens at the corresponding positions.
past_tokens = requests[request_id].token_ids[:-query_token_len]
assert (
np.argmax(logits[logits_offset + i])
== requests[request_id].token_ids[len(past_tokens) + i + 1]
)
query_token_lens = [3, 3, 3, 3]
decode_multi_query_requests = requests
query_len = query_token_lens[0]

(
input_ids,
positions,
seq_lens,
slot_mapping,
_,
block_tables,
) = _prepare_inputs(
decode_multi_query_requests,
cache.slot_mappings,
cache.block_tables,
model.sliding_window,
model.dev,
False, # is_prefill
query_len,
)

input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev)

logits = model.mod["decode_multi_query"](
input_ids,
positions,
seq_lens,
cache.cache,
slot_mapping,
block_tables,
model.params,
)[0].numpy()

logits = np.reshape(logits, (-1, logits.shape[-1]))

logits_offset += query_token_len
verify_logits(logits, query_token_lens)


if __name__ == "__main__":
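To make the new query_token_len path concrete, here is a minimal, self-contained sketch of the bookkeeping the decode branch of _prepare_inputs now does per request (the token and slot values are hypothetical, chosen only to illustrate the indexing):

    # A hypothetical request that has seen 7 tokens so far and decodes its
    # last 3 tokens in a single multi-query step.
    token_ids = [11, 12, 13, 14, 15, 16, 17]
    all_slots = [100, 101, 102, 103, 104, 105, 106]  # one KV-cache slot per token
    query_token_len = 3

    input_ids = token_ids[-query_token_len:]  # [15, 16, 17]
    positions = [len(token_ids) - (query_token_len - i) for i in range(query_token_len)]
    # positions == [4, 5, 6]
    slot_mapping = all_slots[-query_token_len:]  # [104, 105, 106]

With query_token_len=1 this reduces to the previous single-token decode behavior.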
9 changes: 9 additions & 0 deletions mlc_llm/core.py
@@ -392,6 +392,7 @@ class BuildArgs:
"action": "store_true",
},
)
# TODO(masahi): Remove the use of this option with paged_kv_cache_type
use_vllm_attention: bool = field(
default=False,
metadata={
@@ -402,6 +403,10 @@
"action": "store_true",
},
)
paged_kv_cache_type: str = field(
default="vllm",
metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
)

@property
def convert_weight_only(self):
@@ -595,6 +600,9 @@ def mod_transform_before_build(
model_names.append("evaluate")
model_names.append("evaluate_multi_query")

if args.paged_kv_cache_type == "flash-decoding":
model_names.append("decode_multi_query")

if args.sep_embed:
model_names = ["embed", "prefill_with_embed"] + model_names[1:]
if args.enable_batching:
@@ -706,6 +714,7 @@ def dump_build_config(
config: Dict[str, Any] = {
"num_shards": args.num_shards,
"quantization": args.quantization.name,
"paged_kv_cache_type": args.paged_kv_cache_type,
"library_name": args.lib_name,
"build_options": str(args)
}
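For reference, a minimal sketch of the round trip this enables: dump_build_config records the cache type in build_config.json, and get_paged_kv_cache_type in the example script reads it back. The concrete values below are illustrative, not taken from a real build:

    import json
    import os

    artifact_path = "/tmp/artifact"  # hypothetical artifact directory
    os.makedirs(artifact_path, exist_ok=True)

    # Keys follow the config dict above; values are made up for the example.
    build_cfg = {
        "num_shards": 1,
        "quantization": "q4f16_ft",
        "paged_kv_cache_type": "flash-decoding",
        "library_name": "model_lib.so",
        "build_options": "...",
    }
    with open(os.path.join(artifact_path, "build_config.json"), "w", encoding="utf-8") as f:
        json.dump(build_cfg, f)

    # Mirrors get_paged_kv_cache_type() from the example script.
    with open(os.path.join(artifact_path, "build_config.json"), encoding="utf-8") as f:
        assert json.load(f)["paged_kv_cache_type"] == "flash-decoding"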
(The diff for the remaining changed file is not rendered.)
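
The unrendered file presumably holds the model-definition changes named in the commit title. The calling convention of the new decode_multi_query entry point can be read off its call site in run_llama_batched_vllm.py; the shapes in the comments below are inferences from that call site, not taken from the model definition itself:

    # Inferred from the example script above; treat the shapes as assumptions.
    #   input_ids    : [num_sequences, query_len] int32 (reshaped from a flat list)
    #   positions    : one entry per query token, flattened across sequences
    #   seq_lens     : one entry per sequence
    #   slot_mapping : one KV-cache slot per query token, flattened
    #   block_tables : per-sequence lists of physical KV-cache block ids
    logits = model.mod["decode_multi_query"](
        input_ids, positions, seq_lens, cache.cache, slot_mapping, block_tables, model.params
    )[0].numpy()

    # The returned logits are flattened back to [num_sequences * query_len, vocab_size]
    # before being compared against the ground-truth tokens.
    logits = np.reshape(logits, (-1, logits.shape[-1]))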
