From 466b0a65429d339a1c004c5991749e6f9cb1230b Mon Sep 17 00:00:00 2001 From: Alfred Gui Date: Mon, 1 Jul 2024 22:48:56 -0400 Subject: [PATCH] Add the batch concatenation functionality for flashinfer server (#43) * refactor flashinfer causal lm * modify test_local_api * fixes * fixes * lint --- server/examples/test_local_api.py | 32 +- server/text_generation_server/cache.py | 3 + .../layers/flashinfer_attention.py | 6 +- .../models_flashinfer/flashinfer_causal_lm.py | 411 ++++++++---------- .../server_flashinfer.py | 72 +-- .../utils/cache_manager_flashinfer.py | 133 ++---- 6 files changed, 271 insertions(+), 386 deletions(-) diff --git a/server/examples/test_local_api.py b/server/examples/test_local_api.py index 1b35e1de..43756946 100644 --- a/server/examples/test_local_api.py +++ b/server/examples/test_local_api.py @@ -3,7 +3,9 @@ from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2 -from text_generation_server.models_flashinfer.flashinfer_chatglm import FlashinferChatGLM +from text_generation_server.models_flashinfer.flashinfer_chatglm import ( + FlashinferChatGLM, +) import sys try: @@ -29,13 +31,13 @@ # test = "gemma" # test = "llama-3" # test = 'llama-3-70' - # test = "llama-2" + test = "gemma" # test = 'mistral' # test = 'qwen1.5-7' # test = 'qwen1.5-1.8' # test = 'qwen1.5-70' # test = 'qwen2-7' - test = 'chatglm4' + # test = "chatglm4" print("Testing " + test) # Load demo inputs @@ -274,10 +276,8 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): promptOverride="给我讲个故事", ), ] - service = FlashinferQwen2( - model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True - ) - + service = FlashinferQwen2(model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True) + elif test == "chatglm4": # Todo: chatglm4-9b lora adapter requests = [ @@ -288,25 +288,23 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): promptOverride="给我讲个故事", ), ] - service = FlashinferChatGLM( - model_id="THUDM/glm-4-9b-chat", trust_remote_code=True - ) + service = FlashinferChatGLM(model_id="THUDM/glm-4-9b-chat", trust_remote_code=True) print(service.get_lora_adapters()) tokenizer = service.tokenizer batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests)) -pb_batch = FlashinferBatch.from_pb( - batch, tokenizer, torch.float16, torch.device("cuda") -) - -# Add input batch to model service -ids = service.add_request(pb_batch) display_results = {} # Iterative generation: each step generates a token for each input in the batch +isPrefill = True while True: - generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id)) + if isPrefill: + generations, next_batch, _ = service.prefill_batch(batch) + isPrefill = False + else: + generations, next_batch, _, _ = service.decode_batch([next_batch.to_pb()]) + for gen in generations: if gen.prefill_tokens: display_results[gen.request_id] = [ diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 4504733e..117f8499 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -11,6 +11,9 @@ class Cache: def __init__(self): self.cache: Dict[int, B] = {} + def get_all_values(self): + return self.cache.values() + def pop(self, batch_id: int) -> Optional[B]: return self.cache.pop(batch_id, None) diff --git 
a/server/text_generation_server/layers/flashinfer_attention.py b/server/text_generation_server/layers/flashinfer_attention.py index 98380780..73245680 100644 --- a/server/text_generation_server/layers/flashinfer_attention.py +++ b/server/text_generation_server/layers/flashinfer_attention.py @@ -42,7 +42,7 @@ def __init__( 32 * 1024 * 1024, dtype=torch.int8, device=torch.cuda.current_device() ) self.page_size = 16 - + self.group_size = self.num_attention_heads // self.num_key_value_heads def computeAttention( @@ -186,7 +186,9 @@ def _batchDecode( if self.group_size in [7, 16]: decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer=self._workspace_buffer, kv_layout="NHD", use_tensor_cores=True + workspace_buffer=self._workspace_buffer, + kv_layout="NHD", + use_tensor_cores=True, ) else: decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( diff --git a/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py b/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py index 4c66c8f3..bd5c28f4 100644 --- a/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py @@ -1,14 +1,14 @@ import torch import torch.distributed -from typing import Any, TypedDict, Optional +from typing import Any, Optional from text_generation_server.utils.lora_utils import ModelLoraManager, ModelConfigForLora from text_generation_server.utils.cache_manager_flashinfer import ( - ModelKvCache, + getKvCacheBatchPosition, + KvCacheBatchPosition, KvCachePool, + RequestKvCache, ) from text_generation_server.utils.tokens import ( - StopSequenceCriteria, - StoppingCriteria, FinishReason, ) from text_generation_server.layers.flashinfer_attention import find_padded_head_dim @@ -20,119 +20,24 @@ from opentelemetry import trace from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model -from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.types import ( - Batch, Tokens, Generation, GeneratedText, ) -from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, -) from text_generation_server.utils.dist import MEMORY_FRACTION from dataclasses import dataclass +from collections.abc import Iterable +from text_generation_server.cache import Cache tracer = trace.get_tracer(__name__) -class TextGenerationChunk(TypedDict): - index: int - token_id: int - text: str - is_stop: bool - - -@dataclass -class FlashinferBatch(CausalLMBatch): - @classmethod - def Empty(cls, batch_id): - return cls( - batch_id=batch_id, - requests=None, - prefix_offsets=None, - read_offsets=None, - next_token_choosers=None, - stopping_criterias=None, - top_n_tokens=None, - top_n_tokens_tensor=None, - input_ids=None, - requests_idx_mapping=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - all_input_ids=None, - input_lengths=None, - max_input_length=None, - padding_right_offset=None, - max_tokens=None, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase = None, - dtype: torch.dtype = None, - device: torch.device = "cuda", - ) -> "CausalLMBatch": - input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - - # Parse batch - for i, r in enumerate(pb.requests): - prompt = r.inputs - - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, 
device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - tokenized_inputs = tokenizer.encode(prompt) - input_len = len(tokenized_inputs) - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) - input_ids.append(tokenized_inputs) - - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=None, - input_ids=input_ids, - attention_mask=None, - position_ids=None, - past_key_values=None, - all_input_ids=None, - input_lengths=None, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=None, - padding_right_offset=None, - max_tokens=None, - ) - - class RequestContext: def __init__( self, + request_id: str, input_ids: list[int], - lora_id: str, tokenizer, *, temperature: float, @@ -141,8 +46,12 @@ def __init__( top_k: int, maxlen: int, stop_token_id: int, + is_stopped: bool, + request_kv_cache: RequestKvCache, prefill_logprobs: bool = True, + lora_id: str = "empty", ): + self.request_id = request_id self.temperature = temperature self.repetition_penalty = repetition_penalty self.top_p = top_p @@ -172,6 +81,9 @@ def __init__( self.tokenizer = tokenizer self.prefix_offset = 0 self.read_offset = 0 + self.is_stopped = is_stopped + self.prefill_tokens: Optional[Tokens] = None + self.request_kv_cache = request_kv_cache def get_next_token_id(self, logits: torch.Tensor) -> int: if self.logits_processor: @@ -194,15 +106,34 @@ def get_next_token_id(self, logits: torch.Tensor) -> int: def append_token(self, token_id: int): self.output_ids.append(token_id) - def is_stop(self) -> FinishReason: + def get_stop_reason(self) -> FinishReason: if len(self.output_ids) - self.prompt_len >= self.maxlen: return FinishReason.FINISH_REASON_LENGTH if self.output_ids[-1] == self.stop_token_id: return FinishReason.FINISH_REASON_EOS_TOKEN return None - def is_prefill(self) -> bool: - return len(self.output_ids) == self.prompt_len + +@dataclass(frozen=True) +class FlashinferBatch: + batch_id: int + is_prefill: bool + request_contexts: List[RequestContext] + + def to_pb(self) -> generate_pb2.CachedBatch: + + max_input_length = max([r.prompt_len for r in self.request_contexts]) + max_decode_tokens = max([r.maxlen for r in self.request_contexts]) + max_tokens = len(self.request_contexts) * (max_input_length + max_decode_tokens) + + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[ + request_context.request_id for request_context in self.request_contexts + ], + size=len(self.request_contexts), + max_tokens=max_tokens, + ) class FlashinferLM(Model): @@ -213,11 +144,12 @@ def __init__( config: PretrainedConfig, dtype: torch.dtype, device: torch.device, - lora_ids: List[str] = None, + lora_ids: List[str], ): self.device = device self.dtype = dtype self.model_config = config + self.batch_cache = Cache() if ( torch.cuda.is_available() @@ -267,7 +199,7 @@ def __init__( f" Number of Pages to Allocate: {num_pages_to_allocate}" ) - kvCachePool = KvCachePool( + self.kvCachePool = KvCachePool( max_pages=num_pages_to_allocate, num_layers=self.model_config.num_hidden_layers, num_heads=self.model_config.num_key_value_heads, @@ -277,7 +209,6 @@ def __init__( device=device, ) - 
self.modelKvCache = ModelKvCache(kvCachePool) self.model_config_for_lora = ModelConfigForLora( num_hidden_layers=config.num_hidden_layers, hidden_size=config.hidden_size, @@ -287,11 +218,7 @@ def __init__( ) self.loraManager = ModelLoraManager(self.model_config_for_lora, dtype) - if lora_ids: - self.loraManager.set_lora_weights( - lora_ids, self.model_config_for_lora or {}, dtype - ) - self.reqctx: dict[int, RequestContext] = {} + self.loraManager.set_lora_weights(lora_ids, self.model_config_for_lora, dtype) super(FlashinferLM, self).__init__( model=model, @@ -301,13 +228,6 @@ def __init__( device=device, ) - def _find_padded_head_dim(self, head_dim): - flashInferDimensions = [64, 128, 256] - for dim in flashInferDimensions: - if head_dim <= dim: - return dim - raise ValueError("The head dimension is too large for FlashInfer") - def load_lora_adapters(self, lora_ids: List[str]): self.loraManager.set_lora_weights( lora_ids, @@ -321,50 +241,99 @@ def remove_lora_adapters(self, lora_ids: list[str] = None): def get_lora_adapters(self): return list(self.loraManager.lora_weights_cpu) - def has_request(self): - return len(self.reqctx) > 0 + def decode_batch( + self, cachedBatchesPb: Iterable[generate_pb2.CachedBatch] + ) -> Tuple[List[Generation], Optional[FlashinferBatch], Tuple[int, int], int]: + start_concat = time.time_ns() + batch = self._convertCachedBatch(cachedBatchesPb) + concat_ns = time.time_ns() - start_concat + generations, next_batch, timings = self.generate_token(batch) + if next_batch: + self.batch_cache.set(next_batch) + return generations, batch, timings, concat_ns + + def prefill_batch( + self, batchPb: generate_pb2.Batch + ) -> Tuple[List[Generation], Optional[FlashinferBatch], Tuple[int, int]]: + batch = self._convertPbBatch(batchPb) + generations, next_batch, timings = self.generate_token(batch) + if next_batch: + self.batch_cache.set(next_batch) + return generations, batch, timings - @property - def batch_type(self) -> Type[FlashinferBatch]: - return FlashinferBatch + def clear_cache(self): + all_batches: List[FlashinferBatch] = self.batch_cache.get_all_values() + for batch in all_batches: + for request_context in batch.request_contexts: + request_context.request_kv_cache.release() + + self.batch_cache.clear() + + def _find_padded_head_dim(self, head_dim): + flashInferDimensions = [64, 128, 256] + for dim in flashInferDimensions: + if head_dim <= dim: + return dim + raise ValueError("The head dimension is too large for FlashInfer") + + def _convertPbBatch(self, batchPb: generate_pb2.Batch) -> FlashinferBatch: + request_contexts = [] + + for request in batchPb.requests: + prompt = request.inputs + input_ids = self.tokenizer.encode(prompt) + parameters = request.parameters + request_context = RequestContext( + request.id, + input_ids, + self.tokenizer, + temperature=parameters.temperature, + repetition_penalty=parameters.repetition_penalty, + top_p=parameters.top_p, + top_k=parameters.top_k, + maxlen=min(request.stopping_parameters.max_new_tokens, 4096), + stop_token_id=self.tokenizer.eos_token_id, + is_stopped=False, + request_kv_cache=RequestKvCache( + self.kvCachePool, + self.kvCachePool.page_len, + len(input_ids), + ), + prefill_logprobs=request.prefill_logprobs, + lora_id=request.lora_id, + ) + + request_contexts.append(request_context) - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + return FlashinferBatch( + batch_id=batchPb.id, is_prefill=True, 
request_contexts=request_contexts ) - def add_request(self, batch: FlashinferBatch): - ids = [] - for r in range(len(batch.requests)): - id = batch.requests[r].id - # Router sends initial request in each iteration - if id not in self.reqctx: - lora_id = batch.requests[r].lora_id or "empty" - input = batch.input_ids[r] - parameters = batch.requests[r].parameters - stop = batch.requests[r].stopping_parameters - prefill_logprobs = batch.requests[r].prefill_logprobs - - if lora_id not in self.loraManager.lora_weights_cpu: - raise ValueError("Cannot find lora weights", lora_id) - - self.reqctx[id] = RequestContext( - input, - lora_id, - self.tokenizer, - temperature=parameters.temperature, - repetition_penalty=parameters.repetition_penalty, - top_p=parameters.top_p, - top_k=parameters.top_k, - maxlen=min(stop.max_new_tokens, 4096), - stop_token_id=self.tokenizer.eos_token_id, - prefill_logprobs=prefill_logprobs, - ) - ids.append(id) - return ids + def _convertCachedBatch( + self, cachedBatchesPb: Iterable[generate_pb2.CachedBatch] + ) -> FlashinferBatch: + batches: List[FlashinferBatch] = [] + for batch_pb in cachedBatchesPb: + batch = self.batch_cache.pop(batch_pb.id) + if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) + + if len(batches) == 0: + raise ValueError("All batches are empty") + + request_contexts_combined: List[RequestContext] = [] + for batch in batches: + request_contexts_combined.extend(batch.request_contexts) + + return FlashinferBatch( + batch_id=batches[0].batch_id, + is_prefill=False, + request_contexts=request_contexts_combined, + ) - def warmup(self, batch: FlashinferBatch): - pass + def batch_type(self): + return FlashinferBatch @tracer.start_as_current_span("generate_token") @torch.no_grad() @@ -372,64 +341,48 @@ def generate_token( self, batch: FlashinferBatch ) -> Tuple[List[Generation], Optional[FlashinferBatch], Tuple[int, int]]: start = time.time_ns() + input_ids, lora_ids, lora_lens = [], [], [] + request_kv_caches = [] + for request_context in batch.request_contexts: + if not request_context.is_stopped: + if batch.is_prefill: + input_ids.extend(request_context.output_ids) + else: + input_ids.append(request_context.output_ids[-1]) + request_kv_caches.append(request_context.request_kv_cache) + if not batch.is_prefill: + request_context.request_kv_cache.increment() - if hasattr(batch, "requests") and batch.requests: - ids = self.add_request(batch) - - if not self.reqctx: - return None, batch, (0, 0) - - reqs = sorted( - self.reqctx.items(), - key=lambda req: (not req[1].is_prefill(), req[1].lora_id), - ) - - input_ids = [] - lora_ids, lora_lens = [], [] - batchKvCache = self.modelKvCache.getOrCreate(batch.batch_id) - prefill_reqIds = [] - decode_reqIds = [] - - for requestId, req in reqs: - req.prefill = req.is_prefill() - if req.prefill: - input_ids.extend(req.output_ids) - prefill_reqIds.append(requestId) - batchKvCache.create(requestId, req.prompt_len) - else: - input_ids.append(req.output_ids[-1]) - decode_reqIds.append(requestId) - batchKvCache.get(requestId).increment() - if lora_ids and lora_ids[-1] == req.lora_id: - lora_lens[-1] += 1 - else: - lora_ids.append(req.lora_id) - lora_lens.append(1) + if lora_ids and lora_ids[-1] == request_context.lora_id: + lora_lens[-1] += 1 + else: + lora_ids.append(request_context.lora_id) + lora_lens.append(1) - input_ids = torch.tensor( + input_ids_tensor = torch.tensor( input_ids, dtype=torch.long, device=self.device, ) - prefillBatchPosition = 
batchKvCache.getKvCacheBatchPosition( - prefill_reqIds, isPrefill=True + request_kv_caches_prefill = request_kv_caches if batch.is_prefill else [] + request_kv_caches_decode = [] if batch.is_prefill else request_kv_caches + prefillBatchPosition: KvCacheBatchPosition = getKvCacheBatchPosition( + request_kv_caches_prefill, isPrefill=True, device=self.device ) - decodeBatchPosition = batchKvCache.getKvCacheBatchPosition( - decode_reqIds, isPrefill=False + decodeBatchPosition: KvCacheBatchPosition = getKvCacheBatchPosition( + request_kv_caches_decode, isPrefill=False, device=self.device ) - # Forward pass raw_logits, _ = self.model( - input_ids, - self.modelKvCache.kvCachePool, + input_ids_tensor, + self.kvCachePool, prefillBatchPosition, decodeBatchPosition, self.loraManager.get_lora_batched_weights(lora_ids, lora_lens), ) start_decode = time.time_ns() - prefill_logits = ( raw_logits[prefillBatchPosition.seq_indptr[1:] - 1] if prefillBatchPosition.total_seq_len > 0 @@ -440,50 +393,58 @@ def generate_token( all_stop = True generations: List[Generation] = [] - for i, (reqid, reqctx) in enumerate(reqs): - next_token_id = reqctx.get_next_token_id(logits[i].unsqueeze(0)) - reqctx.append_token(next_token_id) + num_stopped_requests = 0 + for i, request_context in enumerate(batch.request_contexts): + if request_context.is_stopped: + num_stopped_requests += 1 + continue + next_token_id = request_context.get_next_token_id( + logits[i - num_stopped_requests].unsqueeze(0) + ) + request_context.append_token(next_token_id) # text = reqctx.decode_tokens() # todo: ?? # special handling for ChatGLM - if 'ChatGLM' in str(type(self.model)): + if "ChatGLM" in str(type(self.model)): text = self.tokenizer.decode( [next_token_id], clean_up_tokenization_spaces=False, - skip_special_tokens=False + skip_special_tokens=False, ) else: text = self.tokenizer.decode( next_token_id, clean_up_tokenization_spaces=False, - skip_special_tokens=False - ) + skip_special_tokens=False, + ) - is_stop = reqctx.is_stop() - if is_stop != None: + stop_reason = request_context.get_stop_reason() + if stop_reason != None: output_text = self.tokenizer.decode( - reqctx.output_ids[reqctx.prompt_len :], + request_context.output_ids[request_context.prompt_len :], clean_up_tokenization_spaces=False, skip_special_tokens=False, ) generated_text = GeneratedText( output_text, - len(reqctx.output_ids) - reqctx.prompt_len + 1, - is_stop, + len(request_context.output_ids) - request_context.prompt_len + 1, + stop_reason, None, ) - self.reqctx.pop(reqid) - batchKvCache.release(reqid) + request_context.is_stopped = True + request_context.request_kv_cache.release() else: generated_text = None all_stop = False # Prefill - if reqctx.prefill: # and reqctx.prefill_logprobs: + if batch.is_prefill: # and request_context.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [] # todo - prefill_token_ids = reqctx.output_ids[: reqctx.prompt_len] + prefill_token_ids = request_context.output_ids[ + : request_context.prompt_len + ] # special handling for ChatGLM - if 'ChatGLM' in str(type(self.model)): + if "ChatGLM" in str(type(self.model)): prefill_texts = self.tokenizer.batch_decode( [prefill_token_ids], clean_up_tokenization_spaces=False, @@ -495,20 +456,19 @@ def generate_token( clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - - reqctx.prefill_tokens = Tokens( + request_context.prefill_tokens = Tokens( prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[], ) - 
reqctx.prefix_offset = reqctx.prompt_len + request_context.prefix_offset = request_context.prompt_len else: - reqctx.prefill_tokens = None + request_context.prefill_tokens = None generation = Generation( - reqid, - reqctx.prefill_tokens, + request_context.request_id, + request_context.prefill_tokens, Tokens( [next_token_id], [0], # prob @@ -525,5 +485,6 @@ def generate_token( decode_ns = time.time_ns() - start_decode # The router stops generation only when batch=None if all_stop: - batch = None - return generations, batch, (forward_ns, decode_ns) + return generations, None, (forward_ns, decode_ns) + else: + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/server_flashinfer.py b/server/text_generation_server/server_flashinfer.py index b8127639..fff4c881 100644 --- a/server/text_generation_server/server_flashinfer.py +++ b/server/text_generation_server/server_flashinfer.py @@ -13,7 +13,8 @@ from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models_flashinfer import Model, get_model +from text_generation_server.models_flashinfer import get_model +from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor @@ -34,7 +35,7 @@ def exit_gracefully(self, signum, frame): class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__( self, - model: Model, + model: FlashinferLM, cache: Cache, quantize: Optional[str], server_urls: List[str], @@ -60,40 +61,24 @@ async def ServiceDiscovery(self, request, context): return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) async def ClearCache(self, request, context): - if request.HasField("id"): - self.cache.delete(request.id) - else: - self.cache.clear() + self.model.clear_cache() return generate_pb2.ClearCacheResponse() - async def FilterBatch(self, request, context): - batch = self.cache.pop(request.batch_id) - if batch is None: - raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - filtered_batch = batch.filter(request.request_ids) - self.cache.set(filtered_batch) + # async def FilterBatch(self, request, context): + # batch = self.cache.pop(request.batch_id) + # if batch is None: + # raise ValueError(f"Batch ID {request.batch_id} not found in cache.") + # filtered_batch = batch.filter(request.request_ids) + # self.cache.set(filtered_batch) - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + # return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): + pass - batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device - ) - max_supported_total_tokens = self.model.warmup(batch) - - return generate_pb2.WarmupResponse( - max_supported_total_tokens=max_supported_total_tokens - ) - - async def Prefill(self, request, context): + async def Prefill(self, request: generate_pb2.PrefillRequest): start = time.time_ns() - batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device - ) - - generations, next_batch, timings = self.model.generate_token(batch) - self.cache.set(next_batch) + generations, next_batch, timings = self.model.prefill_batch(request.batch) return generate_pb2.PrefillResponse( 
generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, @@ -102,32 +87,11 @@ async def Prefill(self, request, context): total_ns=time.time_ns() - start, ) - async def Decode(self, request, context): + async def Decode(self, request: generate_pb2.DecodeRequest): start = time.time_ns() - if len(request.batches) == 0: - raise ValueError("Must provide at least one batch") - - batches = [] - for batch_pb in request.batches: - batch = self.cache.pop(batch_pb.id) - if batch is None: - raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batches.append(batch) - - if len(batches) == 0: - raise ValueError("All batches are empty") - - if len(batches) > 1: - start_concat = time.time_ns() - batch = self.model.batch_type.concatenate(batches) - concat_ns = time.time_ns() - start_concat - else: - batch = batches[0] - concat_ns = None - - generations, next_batch, timings = self.model.generate_token(batch) - self.cache.set(next_batch) - + generations, next_batch, timings, concat_ns = self.model.decode_batch( + request.batches + ) return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, diff --git a/server/text_generation_server/utils/cache_manager_flashinfer.py b/server/text_generation_server/utils/cache_manager_flashinfer.py index bd80c07d..7d34edcf 100644 --- a/server/text_generation_server/utils/cache_manager_flashinfer.py +++ b/server/text_generation_server/utils/cache_manager_flashinfer.py @@ -41,7 +41,7 @@ def __init__( self.device = device self.max_pages = max_pages self.page_len = page_len - self.free_page_mask = torch.ones(max_pages, dtype=torch.bool, device=device) + self.free_page_mask = torch.ones(max_pages, dtype=torch.bool, device="cpu") def allocate(self, num_pages: int): free_page_indices = self.free_page_mask.nonzero() @@ -65,6 +65,7 @@ def __init__(self, kvCachePool: KvCachePool, page_len: int, seq_init_len: int): self.kv_last_page_len = seq_init_len - (init_num_pages - 1) * self.page_len self.kv_page_indices = kvCachePool.allocate(init_num_pages) self.kv_len = seq_init_len + self.is_released = False def increment(self): self.kv_len += 1 @@ -76,92 +77,48 @@ def increment(self): def release(self): self.kvCachePool.deallocate(self.kv_page_indices) - - -class BatchKvCache: - def __init__(self, kvCachePool: KvCachePool, page_len, device): - self.kvCachePool = kvCachePool - self.page_len = page_len - self.device = device - self.kvCacheDict: dict[int, RequestKvCache] = {} - - def get(self, req_id): - return self.kvCacheDict.get(req_id) - - def create(self, req_id, seq_init_len): - self.kvCacheDict[req_id] = RequestKvCache( - self.kvCachePool, self.page_len, seq_init_len - ) - return self.kvCacheDict[req_id] - - def release(self, req_id): - self.kvCacheDict[req_id].release() - del self.kvCacheDict[req_id] - - def increment(self): - for kvCache in self.kvCacheDict.values(): - kvCache.increment() - - def setRequestOrder(self, requestIds: List[int]): - self.requestIds = requestIds - - def getKvCacheBatchPosition(self, requestIds: List[int], isPrefill: bool): - kv_page_indices_list = [] - kv_page_indptr_list = [] - seq_indptr_list = [] - kv_last_page_len_list = [] - seq_lens_list = [] - cum_pages = 0 - cum_seq_len = 0 - for requestId in requestIds: - kvCache = self.kvCacheDict[requestId] - kv_page_indices_list.extend(kvCache.kv_page_indices) - kv_page_indptr_list.append(cum_pages) - seq_indptr_list.append(cum_seq_len) - 
kv_last_page_len_list.append(kvCache.kv_last_page_len) - seq_lens_list.append(kvCache.kv_len) - cum_pages += len(kvCache.kv_page_indices) - cum_seq_len += kvCache.kv_len if isPrefill else 1 - + self.is_released = True + + +def getKvCacheBatchPosition( + request_kv_caches: List[RequestKvCache], isPrefill: bool, device: torch.device +) -> KvCacheBatchPosition: + kv_page_indices_list = [] + kv_page_indptr_list = [] + seq_indptr_list = [] + kv_last_page_len_list = [] + seq_lens_list = [] + cum_pages = 0 + cum_seq_len = 0 + for request_kv_cache in request_kv_caches: + kv_page_indices_list.extend(request_kv_cache.kv_page_indices) kv_page_indptr_list.append(cum_pages) seq_indptr_list.append(cum_seq_len) - kv_page_indices = torch.tensor( - kv_page_indices_list, dtype=torch.int32, device=self.device - ) - kv_page_indptr = torch.tensor( - kv_page_indptr_list, dtype=torch.int32, device=self.device - ) - kv_last_page_len = torch.tensor( - kv_last_page_len_list, dtype=torch.int32, device=self.device - ) - seq_indptr = torch.tensor( - seq_indptr_list, dtype=torch.int32, device=self.device - ) - seq_lens = torch.tensor( - seq_lens_list, - dtype=torch.int32, - device=self.device, - ) - return KvCacheBatchPosition( - seq_indptr=seq_indptr, - kv_page_indptr=kv_page_indptr, - kv_page_indices=kv_page_indices, - kv_last_page_len=kv_last_page_len, - seq_lens=seq_lens, - total_seq_len=cum_seq_len, - ) - - -class ModelKvCache: - def __init__(self, kvCachePool: KvCachePool): - self.kvCachePool = kvCachePool - self.device = kvCachePool.device - self.page_len = kvCachePool.page_len - self.batchKvCacheDict: dict[int, BatchKvCache] = {} - - def getOrCreate(self, batch_id): - batchKvCache = self.batchKvCacheDict.get(batch_id) or BatchKvCache( - self.kvCachePool, self.page_len, self.device - ) - self.batchKvCacheDict[batch_id] = batchKvCache - return batchKvCache + kv_last_page_len_list.append(request_kv_cache.kv_last_page_len) + seq_lens_list.append(request_kv_cache.kv_len) + cum_pages += len(request_kv_cache.kv_page_indices) + cum_seq_len += request_kv_cache.kv_len if isPrefill else 1 + + kv_page_indptr_list.append(cum_pages) + seq_indptr_list.append(cum_seq_len) + kv_page_indices = torch.tensor( + kv_page_indices_list, dtype=torch.int32, device=device + ) + kv_page_indptr = torch.tensor(kv_page_indptr_list, dtype=torch.int32, device=device) + kv_last_page_len = torch.tensor( + kv_last_page_len_list, dtype=torch.int32, device=device + ) + seq_indptr = torch.tensor(seq_indptr_list, dtype=torch.int32, device=device) + seq_lens = torch.tensor( + seq_lens_list, + dtype=torch.int32, + device=device, + ) + return KvCacheBatchPosition( + seq_indptr=seq_indptr, + kv_page_indptr=kv_page_indptr, + kv_page_indices=kv_page_indices, + kv_last_page_len=kv_last_page_len, + seq_lens=seq_lens, + total_seq_len=cum_seq_len, + )
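
For reference: the decode-time batch concatenation above ultimately feeds FlashInfer a CSR-style batch position (kv_page_indptr, seq_indptr, kv_last_page_len) that getKvCacheBatchPosition builds from the per-request page tables kept in RequestKvCache. The sketch below is a minimal, standalone illustration of that layout using plain Python lists instead of torch tensors; it folds RequestKvCache's page accounting into one hypothetical helper (batch_position is not a repository function), and the page length of 16 and ceil-division page counts are assumptions read off the surrounding code.

def batch_position(kv_lens, page_len=16, is_prefill=True):
    # Mirrors the cum_pages / cum_seq_len accumulation in getKvCacheBatchPosition:
    # one cumulative entry per request, plus the leading zero.
    kv_page_indptr, seq_indptr, kv_last_page_len = [0], [0], []
    cum_pages, cum_seq_len = 0, 0
    for kv_len in kv_lens:
        num_pages = (kv_len + page_len - 1) // page_len  # pages backing this request (assumed ceil division)
        kv_last_page_len.append(kv_len - (num_pages - 1) * page_len)
        cum_pages += num_pages
        cum_seq_len += kv_len if is_prefill else 1  # decode feeds one new token per request
        kv_page_indptr.append(cum_pages)
        seq_indptr.append(cum_seq_len)
    return kv_page_indptr, seq_indptr, kv_last_page_len

# Two requests with 20 and 5 cached tokens merged into a single batch:
assert batch_position([20, 5], is_prefill=True) == ([0, 2, 3], [0, 20, 25], [4, 5])
assert batch_position([20, 5], is_prefill=False) == ([0, 2, 3], [0, 1, 2], [4, 5])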