diff --git a/examples/metrics-monitoring/README.md b/examples/metrics-monitoring/README.md
index 64ef1160c66b..62150f3a10b7 100644
--- a/examples/metrics-monitoring/README.md
+++ b/examples/metrics-monitoring/README.md
@@ -2,3 +2,40 @@ ## Continuous Batching Metrics in Transformers
+To set up metrics monitoring with continuous batching, you will need Tempo and Prometheus running.
+
+For this, we provide a Docker Compose file in `examples/metrics-monitoring`.
+
+To run it:
+
+```sh
+cd examples/metrics-monitoring
+docker compose up
+```
+
+Then, in your script running CB, you will need to create a `MeterProvider` and a `TracerProvider` as follows:
+
+```py
+from opentelemetry import metrics, trace
+from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+
+resource = Resource.create({"service.name": "transformers"})
+
+metrics_exporter = PeriodicExportingMetricReader(
+    OTLPMetricExporter(endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"),  # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var
+    export_interval_millis=1000
+)
+meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter])
+metrics.set_meter_provider(meter_provider)
+
+trace_exporter = OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces")  # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var
+tracer_provider = TracerProvider(resource=resource)
+tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
+trace.set_tracer_provider(tracer_provider)
+```
diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index 821e3d9a271b..b5ad94ed3f11 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -1,4 +1,22 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
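+"""Benchmark continuous batching (CB) generation with `model.generate_batch`.
+
+The script loads Qwen/Qwen3-4B-Instruct-2507 with a configurable attention implementation
+(--attn), tokenizes questions from the GSM8K test split, and times `model.generate_batch`.
+Optional flags compile the model (--compile), compare CB outputs against sequential
+`model.generate` (--compare), export OpenTelemetry metrics as described in
+`examples/metrics-monitoring/README.md` (--metrics), and save per-request outputs plus
+throughput stats to a JSON file (--output-file).
+
+Example usage:
+    python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 \
+        --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
+"""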
+import argparse +import json +import os import time +from typing import Optional import datasets import torch @@ -7,108 +25,247 @@ from transformers.generation import GenerationConfig -torch.set_float32_matmul_precision("high") +MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507" -model_id = "meta-llama/Llama-3.2-3b-Instruct" -model = ( - AutoModelForCausalLM.from_pretrained( - model_id, - attn_implementation="paged_attention|kernels-community/flash-attn", - dtype=torch.bfloat16, + +def generate_simple( + attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig +) -> list[str]: + attn_implementation = { + "sdpa_paged": "sdpa", + "eager_paged": "eager", + "flash_paged": "flash_attention_2", + }[attn_implementation] + + model = ( + AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + attn_implementation=attn_implementation, + ) + .cuda() + .eval() ) - .eval() - .cuda() -) -tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") - -generation_config = GenerationConfig( - max_new_tokens=512, - # use_cuda_graph=False, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=False, -) - -train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") -train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version -print("--- Running CB Generation Example ---") - - -def tokenize_function(examples): - return tokenizer(examples["question"]) - - -tokenized_datasets = train_dataset.map(tokenize_function, batched=True) -simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] - -start_time_simple = time.time() -model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") -batch_outputs = model.generate_batch( - inputs=simple_batch_inputs, - generation_config=generation_config, -) -end_time_simple = time.time() -token_count = 0 -for request in batch_outputs: - input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) + + decoded_outputs = [] + for input_ids in simple_batch_inputs: + input_ids = torch.tensor([input_ids]).to("cuda") + attention_mask = torch.ones_like(input_ids) + outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config) + generated_tokens = outputs[0][input_ids.shape[1] :] + decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True) + decoded_outputs.append(decoded_output) + + return decoded_outputs + + +def setup_metrics(): try: - output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) - token_count += len(batch_outputs[request].generated_tokens[1:]) + from opentelemetry import metrics, trace + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + resource = Resource.create({"service.name": "transformers"}) + metrics_exporter = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint="http://localhost:9090/api/v1/otlp/v1/metrics" + ), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var + export_interval_millis=1000, + ) + meter_provider = 
MeterProvider(resource=resource, metric_readers=[metrics_exporter]) + metrics.set_meter_provider(meter_provider) + trace_exporter = OTLPSpanExporter( + endpoint="http://localhost:4318/v1/traces" + ) # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + trace.set_tracer_provider(tracer_provider) except Exception as e: - print(f"Decoding failed for request {request}: {e}") - token_count += len(batch_outputs[request].generated_tokens[1:]) - output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) - if len(output_text) > 0: + print(f"Error setting up metrics: {e}") + + +def batch_generate( + model: AutoModelForCausalLM, + simple_batch_inputs: list, + generation_config: GenerationConfig, + tokenizer: AutoTokenizer, + displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs + output_file: Optional[str] = None, + expected_outputs: Optional[list[str]] = None, + slice_inputs: bool = True, +) -> tuple[float, float]: + # Actual batch generation + if displayed_samples >= 0: + print("--- Running CB Generation Example ---") + start_time_simple = time.time() + batch_outputs = model.generate_batch( + inputs=simple_batch_inputs, + generation_config=generation_config, + slice_inputs=slice_inputs, # TODO: move this to the generation config + ) + end_time_simple = time.time() + if displayed_samples >= 0: + print("Done with batch generation.") + + # Decode outputs + token_count = 0 + data = [] + for i, request in enumerate(batch_outputs): + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True) + data.append({"input": input_text}) + + # Try to decode the output + try: + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True) + token_count += len(batch_outputs[request].generated_tokens[1:]) + data[-1]["output"] = output_text + except Exception as e: + print(f"Decoding failed for request {request}: {e}") + data[-1]["output"] = "__ERROR__" + continue + + # Display sample if asked + if i < displayed_samples: + if len(output_text) > 0: + print("-" * 20) + print(f"{request} Input: {input_text}") + print(f"{request} Output: {output_text}") + else: + print(f"{request} Input: {input_text}") + print("[WARN]") + print(f"{request} Output was empty!") + + # Compare with classic generate if asked + if expected_outputs is not None: + matches = output_text == expected_outputs[i] + data[-1]["ref"] = expected_outputs[i] + data[-1]["matches"] = matches + print(f"Request {i} matches" if matches else f"Request {i} does NOT match!") + + # Compute stats and maybe print them + gen_time = end_time_simple - start_time_simple + tok_per_sec = token_count / gen_time + if displayed_samples >= 0: print("-" * 20) - print(f"{request} Input: {input_text}") - print(f"{request} Output: {output_text}") - else: - print("", end="\r\r\r\r") -print("-" * 20) -print("--- Finished CB Generation Example ---\n\n") + print("--- Finished CB Generation Example ---\n") + print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. 
{tok_per_sec:.2f}tok/s")
+    stats = {
+        "num_blocks": generation_config.num_blocks,
+        "max_batch_tokens": generation_config.max_batch_tokens,
+        "gen_time": gen_time,
+        "token_count": token_count,
+        "tok_per_sec": tok_per_sec,
+    }
+    # If an output file is provided, save the reordered data to it
+    data.sort(key=lambda x: x["input"])
+    data = [stats] + data
+    if output_file is not None:
+        with open(output_file, "w") as f:
+            json.dump(data, f, indent=4)
-print(
-    f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s"
-)
+    return gen_time, tok_per_sec
-# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
+if __name__ == "__main__":
+    # Parse args
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-blocks", "-n", type=int, default=None)
+    parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
-# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
-# simple_batch_inputs = list(tokenized_test_prompts["input_ids"])
+    parser.add_argument(
+        "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
+    )
+    parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
+    parser.add_argument("--slice-inputs", action="store_true", default=False)
+    parser.add_argument("--use-cuda-graph", action="store_true", default=False)
+    parser.add_argument("--compile", action="store_true", default=False)
-# def tokenize_function(examples):
-#     # Truncate to avoid overly long prompts exceeding max context length
-#     return tokenizer(examples["question"], padding=True, truncation=True, max_length=512)
+    parser.add_argument("--samples", type=int, default=500)
+    parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
+    parser.add_argument("--output-file", type=str, default=None)
+    parser.add_argument("--compare", action="store_true", default=False)
+    parser.add_argument("--metrics", action="store_true", default=False)
+    args = parser.parse_args()
+    # If turned on, we set up metrics
+    if args.metrics:
+        setup_metrics()
-# tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
-# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
+    # Set matmul precision if not none
+    if args.matmul_precision != "none":
+        torch.set_float32_matmul_precision(args.matmul_precision)
+    # Prepare model
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        attn_implementation=args.attn,
+        dtype=torch.bfloat16,
+    )
+    model = model.cuda().eval()
-# model.config.attn_implementation = "sdpa"
-# start_time_simple = time.time()
-# batch_size = 64
-# full_outputs = []
-# from tqdm import tqdm
+    # If turned on, we compile the model
+    if args.compile:
+        model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+    if args.slice_inputs:
+        assert not args.compile, "Slicing inputs requires the model to not be compiled"
+        assert not args.use_cuda_graph, "Slicing inputs is not compatible with cuda graphs"
-# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)):
-#     outputs = model.generate(
-#         torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device),
-#         generation_config=GenerationConfig(
-#             max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
-#         ),
-#     )
-# 
full_outputs.extend(outputs.tolist()) + # Prepare tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") + dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") + dataset = dataset.select(range(args.samples)) # Use only 5 examples for the simple version + tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) + simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] -# end_time_simple = time.time() -# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds") + # Prepare generation config + generation_config = GenerationConfig( + max_new_tokens=512, + use_cuda_graph=args.use_cuda_graph, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + num_blocks=args.num_blocks, + max_batch_tokens=args.max_batch_tokens, + ) + + # If we need to compare, we need to generate the reference outputs + expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None + + # If no output file is provided, we pick a name based on the args + if args.output_file is None: + os.makedirs("runs/cb", exist_ok=True) + attn = args.attn.replace("|", "_").replace("/", "_") + args.output_file = ( + f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json" + ) + + # Run warmup batch generation + batch_generate( + model, + simple_batch_inputs[: min(5, args.samples)], + generation_config, + tokenizer, + displayed_samples=-1, + slice_inputs=args.slice_inputs, + ) + + # Run batch generation + gen_time, tok_per_sec = batch_generate( + model, + simple_batch_inputs, + generation_config, + tokenizer, + displayed_samples=args.displayed, + output_file=args.output_file, + expected_outputs=expected_outputs, + slice_inputs=args.slice_inputs, + ) -# print("\nResults from simple generate_batch:") -# for i, request in enumerate(full_outputs): -# output_text = tokenizer.decode(request, skip_special_tokens=False) -# print("-" * 20) -# print(f" Output: {output_text}") -# print("-" * 20) -# print("--- Finished Simple Batch Generation Example ---\n\n") +# Example usage: +# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json diff --git a/examples/pytorch/continuous_batching_simple.py b/examples/pytorch/continuous_batching_simple.py new file mode 100644 index 000000000000..3ae5e3d83870 --- /dev/null +++ b/examples/pytorch/continuous_batching_simple.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
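+"""Minimal continuous batching (CB) example.
+
+Loads Qwen/Qwen3-4B-Instruct-2507, tokenizes questions from the GSM8K test split, runs a small
+warmup pass, then times `model.generate_batch`, printing a few decoded samples and the measured
+tokens-per-second throughput.
+"""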
+import argparse +import time + +import datasets +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + + +MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507" +DISPLAYED_SAMPLES = 3 + + +if __name__ == "__main__": + # Parse args + parser = argparse.ArgumentParser() + parser.add_argument("--num-blocks", "-n", type=int, default=None) + parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) + parser.add_argument( + "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation" + ) + parser.add_argument("--samples", type=int, default=500) + args = parser.parse_args() + + # Prepare model + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + attn_implementation=args.attn, + dtype=torch.bfloat16, + ) + model = model.cuda().eval() + + # Prepare tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") + dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") + dataset = dataset.select(range(args.samples)) + tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) + simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + + # Prepare generation config + generation_config = GenerationConfig( + max_new_tokens=512, + use_cuda_graph=False, # Not supported for simple version + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + num_blocks=args.num_blocks, + max_batch_tokens=args.max_batch_tokens, + ) + + # Warmup iterations + _ = model.generate_batch( + inputs=simple_batch_inputs[: min(5, args.samples)], + generation_config=generation_config, + slice_inputs=True, + ) + + # Actual batch generation + print("--- Running CB Generation Example ---") + start_time = time.time() + batch_outputs = model.generate_batch( + inputs=simple_batch_inputs, + generation_config=generation_config, + slice_inputs=True, + ) + end_time = time.time() + print("Done with batch generation.") + + # Decode outputs + token_count = 0 + for i, request in enumerate(batch_outputs): + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True) + # Try to decode the output + try: + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True) + token_count += len(batch_outputs[request].generated_tokens[1:]) + except Exception as e: + print(f"Decoding failed for request {request}: {e}") + continue + + # Display sample if asked + if i < DISPLAYED_SAMPLES: + print("-" * 20) + print(f"{request} Input: {input_text}") + if len(output_text) > 0: + print(f"{request} Output: {output_text}") + else: + print(f"[WARN] {request} Output was empty!") + + # Compute stats and maybe print them + gen_time = end_time - start_time + tok_per_sec = token_count / gen_time + print("-" * 20) + print("--- Finished CB Generation Example ---\n") + print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. 
{tok_per_sec:.2f}tok/s") diff --git a/setup.py b/setup.py index 3b67610db313..79bb0f9ef0d8 100644 --- a/setup.py +++ b/setup.py @@ -445,7 +445,7 @@ def run(self): extras["benchmark"] = deps_list("optimum-benchmark") # OpenTelemetry dependencies for metrics collection in continuous batching -extras["open-telemetry"] = deps_list("opentelemetry-api") +extras["open-telemetry"] = deps_list("opentelemetry-api") + ["opentelemetry-exporter-otlp", "opentelemetry-sdk"] # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py deleted file mode 100644 index 2d903da05b62..000000000000 --- a/src/transformers/generation/continuous_batching.py +++ /dev/null @@ -1,1459 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import queue -import threading -import time -from abc import ABC, abstractmethod -from collections import deque -from dataclasses import dataclass, field -from enum import Enum -from functools import partial -from typing import Optional, Union - -import torch -import torch.nn as nn -from tokenizers.decoders import DecodeStream -from tqdm import tqdm - -from ..configuration_utils import PretrainedConfig -from ..generation.configuration_utils import GenerationConfig -from ..tokenization_utils_fast import PreTrainedTokenizerFast -from ..utils.logging import logging -from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced - - -class RequestStatus(Enum): - """Status of a generation request through its lifecycle.""" - - PENDING = "pending" - PREFILLING = "prefilling" - PREFILLING_SPLIT = "prefilling_split" - SPLIT_PENDING_REMAINDER = "split_pending_remainder" - DECODING = "decoding" - FINISHED = "finished" - FAILED = "failed" - - -logger = logging.getLogger(__name__) - - -@dataclass -class GenerationOutput: - """Tracks the output of a generation request. - - Attributes: - request_id (str): The ID of the generation request. - prompt_ids (list[int]): The IDs of the prompt tokens. - generated_tokens (list[int]): The generated tokens. - logprobs (list[float]): The log probabilities of the generated tokens. - error (Optional[str]): Any error message associated with the request. When None, the request was successful. - """ - - request_id: str - prompt_ids: list[int] = field(default_factory=list) - generated_tokens: list[int] = field(default_factory=list) - logprobs: list[float] = field(default_factory=list) - error: Optional[str] = None - status: RequestStatus = RequestStatus.PENDING - created_time: float = field(default_factory=time.time) - next_token: Optional[int] = field(default_factory=int) - - -@dataclass -class RequestState: - """Tracks the state of a generation request through its lifecycle. 
- - Attributes: - status (RequestStatus): can be one of PENDING, PREFILLING, PREFILLING_SPLIT, - SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED - """ - - # Required fields - request_id: str - prompt_ids: Optional[list[int]] = None # the one being processed - full_prompt_ids: Optional[list[int]] = None # the full prompt - remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests - static_outputs: list[int] = field(default_factory=list) - allocated_blocks: list[int] = field(default_factory=list) - position_offset: int = 0 # Current position in the sequence for position_ids - status: RequestStatus = RequestStatus.PENDING - max_new_tokens: int = 20 - eos_token_id: int = -1 - created_time: float = field(default_factory=time.time) - error: Optional[str] = None - next_token: Optional[str] = None - - def current_len(self) -> int: - """Get the current length of the sequence (prompt + generated tokens).""" - return self.position_offset - - def generated_len(self) -> int: - """Get the number of tokens generated so far.""" - return len(self.static_outputs) - - @traced - def update_with_token(self, token_id: int) -> bool: - """Update the request with a newly generated token and check for completion. - - Args: - token_id: The token ID to add to the output sequence - - Returns: - bool: True if the request is now complete, False otherwise - """ - # Only update if we're in decoding state - if self.status != RequestStatus.DECODING: - return False - - is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 - is_max_len = self.generated_len() >= self.max_new_tokens - - # Only add the token if we're not finishing due to max length - # (EOS tokens should still be added to the output) - if not (is_max_len and not is_eos): - self.static_outputs.extend([token_id]) - - if is_eos or is_max_len: - self.status = RequestStatus.FINISHED - return True - return False - - def __repr__(self): - return f"RequestState(\n\trequest_id={self.request_id},\n\tstatus={self.status},\n\tout_tokens={self.generated_len()},\n\tquery_length={len(self.prompt_ids)}, \n\tremaining_tokens={len(self.remaining_prompt_ids)}, \n\tkv_length={self.position_offset}\n\tfull_prompt_lenght={len(self.full_prompt_ids)},\n\tallocated_blocks={self.allocated_blocks},\n\tgenerated_tokens={self.static_outputs}\n)" - - def to_generation_output(self): - """Convert the request state to a GenerationOutput object.""" - return GenerationOutput( - request_id=self.request_id, - prompt_ids=self.full_prompt_ids, - status=self.status, - generated_tokens=self.static_outputs, - logprobs=[], - error=self.error, - next_token=self.next_token, - ) - - -@attach_tracer() -class PagedAttentionCache: - def __init__( - self, - config: PretrainedConfig, - generation_config: GenerationConfig, - device: torch.device, - dtype: torch.dtype = torch.float16, - num_requests: int = 100, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - tp_size: Optional[int] = None, - ) -> None: - """Initialize a paged attention cache for efficient memory usage. 
- - Args: - config: Model configuration - generation_config: Generation configuration containing cache parameters - device: Device for the cache tensors - dtype: Data type for the cache tensors - layer_device_map: Optional mapping of layer indices to devices - initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size - """ - # Extract model dimensions - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - num_key_value_heads = self.num_key_value_heads - if tp_size is not None and tp_size > 1: - if num_key_value_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." - ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - # self.num_key_value_heads //= tp_size - - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") and config.head_dim is not None - else config.hidden_size // config.num_attention_heads - ) - self.num_hidden_layers = config.num_hidden_layers - - # Calculate optimal block size and number if not provided - num_blocks = getattr(generation_config, "num_blocks", 1024) - block_size = getattr(generation_config, "block_size", 32) - max_memory_percent = getattr(generation_config, "max_memory", 0.9) - max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256) - if num_blocks is None or max_batch_tokens is None: - num_blocks, max_batch_tokens = compute_optimal_blocks( - generation_config.max_new_tokens, - block_size=block_size, - head_dim=self.head_dim, - num_layers=self.num_hidden_layers, - num_heads=self.num_key_value_heads, - max_memory_percent=max_memory_percent, - dtype=dtype, - num_blocks=num_blocks, - ) - logger.warning( - f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}" - ) - self.max_batch_tokens = max_batch_tokens - self.block_size = block_size - self.num_blocks = num_blocks - self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim) - - self.dtype = dtype - self.device = device - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - for idx in range(config.num_hidden_layers): - layer_device = layer_device_map[idx] if layer_device_map is not None else device - new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) - new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. 
- torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - # Block management data structures - self._free_blocks = deque(range(num_blocks)) - self._block_tables: dict[str, list[int]] = {} - - @traced - def allocate_blocks(self, n_blocks: int, request_id: str) -> list[int]: - """Allocates n_blocks for a given request_id.""" - if len(self._free_blocks) < n_blocks: - return False - - allocated = [] - for _ in range(n_blocks): - allocated.append(self._free_blocks.popleft()) - - if request_id not in self._block_tables: - self._block_tables[request_id] = [] - self._block_tables[request_id].extend(allocated) - return allocated - - @traced - def free_blocks(self, request_id: str) -> None: - """Frees all blocks associated with a request_id.""" - if request_id in self._block_tables: - blocks_to_free = self._block_tables.pop(request_id) - self._free_blocks.extend(blocks_to_free) - else: - logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}") - - def get_num_free_blocks(self) -> int: - """Returns the number of free blocks available.""" - return len(self._free_blocks) - - def get_block_table(self, request_id: str) -> list[int]: - """Returns the block table for a request.""" - return self._block_tables.get(request_id, []) - - @traced - def _get_physical_indices(self, state: RequestState, logical_indices: list[int]) -> list[int]: - """ - Maps logical sequence indices to physical cache indices using the block table, using PyTorch. - - Args: - request_id: The request ID. - logical_indices: A list of logical indices. - - Returns: - A list of physical indices. - - Raises: - ValueError: If no block table is found for the request ID. - IndexError: If a logical index maps to a block index that is out of bounds. - """ - request_id = state.request_id - block_table = self._block_tables.get(request_id) - if not block_table: - raise ValueError(f"No block table found for request {request_id}") - - block_size = self.block_size - physical_indices = [] - - for idx in logical_indices: - block_idx = idx // block_size - block_offset = idx % block_size - - if block_idx >= len(block_table): - raise IndexError( - f"Logical index {idx} maps to block index {block_idx} which is out of bounds " - f"for request {request_id}" - ) - - physical_block_num = block_table[block_idx] - physical_index = physical_block_num * block_size + block_offset - physical_indices.append(physical_index) - - return physical_indices - - @traced - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - read_index, - write_index, - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Reshape cache for easier indexing - total_slots = self.num_blocks * self.block_size - k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) - v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) - k_cache_flat[:, write_index, :] = key_states[0] - v_cache_flat[:, write_index, :] = value_states[0] - return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :] - - -class Scheduler(ABC): - """ - Abstract base class for scheduling requests in the continuous batch processor. - It is expected that cache allocation and scheduling logic will be implemented in subclasses. 
- """ - - def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False): - self.active_requests: dict[str, RequestState] = {} - self.waiting_requests: dict[str, RequestState] = {} - self.waiting_requests_order: deque[str] = deque() - self.cache = cache - self.retain_cache_on_finish = retain_cache_on_finish - - @abstractmethod - def add_waiting_request(self, state: RequestState): - """Add a request to the waiting list.""" - pass - - @abstractmethod - def schedule_batch(self, token_budget: int) -> list[RequestState]: - pass - - @traced - def has_pending_requests(self) -> bool: - """Check if there are requests ready to be processed.""" - return len(self.active_requests) or len(self.waiting_requests) - - @abstractmethod - def finish_request(self, request_id: str, evict_from_cache: bool = True): - """Finish processing a request and free its allocated blocks.""" - pass - - @traced - def get_active_request_static_outputs(self, request_id: str) -> list[int]: - if request_id in self.active_requests: - return self.active_requests[request_id].static_outputs - return [] - - -@attach_tracer() -class FIFOScheduler(Scheduler): - @traced - def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): - # 1. we check that the occupancy is less than the requested length - # 2. we allocate enough blocks to cover the requested length - current_len = state.current_len() - occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len - if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): - blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 - allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) - if not allocated: - return False - state.allocated_blocks.extend(allocated) - return True - - @traced(span_name="prepare_request") - def _prepare_request_for_processing( - self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] - ): - """Prepare a request for processing in the current batch.""" - request_tokens = ( - state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids - ) - if len(request_tokens) < token_budget: - # Can process the entire prompt/remainder - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING - state.prompt_ids = state.remaining_prompt_ids - state.remaining_prompt_ids = [] - else: - # Need to split the request - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING_SPLIT - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING_SPLIT - state.remaining_prompt_ids = request_tokens[token_budget:] - state.prompt_ids = request_tokens[:token_budget] - - @traced - def add_waiting_request(self, state: RequestState): - """Add a request to the waiting list.""" - if self.retain_cache_on_finish and state.request_id in self.active_requests: - old_state = self.active_requests.pop(state.request_id) - state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] - state.allocated_blocks = old_state.allocated_blocks - state.position_offset = old_state.position_offset - 
self.waiting_requests[state.request_id] = state - self.waiting_requests_order.append(state.request_id) - - @traced - def schedule_batch(self, token_budget: int) -> list[RequestState]: - priority_states: list[RequestState] = [] - second_priority_states: list[RequestState] = [] - scheduled_requests = [] - - for state in self.active_requests.values(): - if state.status == RequestStatus.DECODING: - priority_states.append(state) - if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - second_priority_states.append(state) - - # Add waiting requests to second priority - for req_id in self.waiting_requests_order: - second_priority_states.append(self.waiting_requests[req_id]) - - candidates = priority_states + second_priority_states - request_ids_to_remove_from_waiting = set() - - for state in candidates: - self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) - request_len = len(state.prompt_ids) - if not self._allocate_blocks_if_needed( - state, len(state.prompt_ids) - ): # don't schedule if we can't allocate blocks - if len(self.cache._free_blocks) == 0: - break - continue - - @traced - def _add_to_scheduled_requests(state: RequestState): - scheduled_requests.append(state) - - _add_to_scheduled_requests(state) - - token_budget -= request_len - - @traced - def _remove_from_waiting_requests(state: RequestState): - req_id = state.request_id - if req_id in self.waiting_requests: - del self.waiting_requests[req_id] - request_ids_to_remove_from_waiting.add(req_id) - - _remove_from_waiting_requests(state) - - if token_budget == 0: - break - - self.waiting_requests_order = deque( - [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] - ) - - return scheduled_requests - - @traced - def finish_request(self, request_id: str, evict_from_cache: bool = True): - if evict_from_cache: - self.cache.free_blocks(request_id) - if request_id in self.active_requests: - del self.active_requests[request_id] - - -@attach_tracer() -class PrefillFirstScheduler(Scheduler): - @traced - def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): - # 1. we check that the occupancy is less than the requested length - # 2. 
we allocate enough blocks to cover the requested length - current_len = state.current_len() - occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len - if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): - blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 - allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) - if not allocated: - return False - state.allocated_blocks.extend(allocated) - return True - - @traced(span_name="prepare_request") - def _prepare_request_for_processing( - self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] - ): - """Prepare a request for processing in the current batch.""" - request_tokens = ( - state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids - ) - if len(request_tokens) < token_budget: - # Can process the entire prompt/remainder - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING - state.prompt_ids = state.remaining_prompt_ids - state.remaining_prompt_ids = [] - else: - # Need to split the request - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING_SPLIT - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING_SPLIT - state.remaining_prompt_ids = request_tokens[token_budget:] - state.prompt_ids = request_tokens[:token_budget] - - @traced - def add_waiting_request(self, state: RequestState): - """Add a request to the waiting list.""" - if self.retain_cache_on_finish and state.request_id in self.active_requests: - old_state = self.active_requests.pop(state.request_id) - state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error? 
- state.allocated_blocks = old_state.allocated_blocks - state.position_offset = old_state.position_offset - self.waiting_requests[state.request_id] = state - self.waiting_requests_order.append(state.request_id) - - @traced - def schedule_batch(self, token_budget: int) -> list[RequestState]: - priority_states: list[RequestState] = [] - second_priority_states: list[RequestState] = [] - scheduled_requests = [] - - for state in self.active_requests.values(): - if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - priority_states.append(state) - elif state.status == RequestStatus.DECODING: - second_priority_states.append(state) - - for req_id in self.waiting_requests_order: - second_priority_states.append(self.waiting_requests[req_id]) - - candidates = priority_states + second_priority_states - - request_ids_to_remove_from_waiting = set() - - for state in candidates: - self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) - request_len = len(state.prompt_ids) - if not self._allocate_blocks_if_needed( - state, len(state.prompt_ids) - ): # don't schedule if we can't allocate blocks - if len(self.cache._free_blocks) == 0: - break - continue - - @traced - def _add_to_scheduled_requests(state: RequestState): - scheduled_requests.append(state) - - _add_to_scheduled_requests(state) - - token_budget -= request_len - - @traced - def _remove_from_waiting_requests(state: RequestState): - req_id = state.request_id - if req_id in self.waiting_requests: - del self.waiting_requests[req_id] - request_ids_to_remove_from_waiting.add(req_id) - - _remove_from_waiting_requests(state) - - if token_budget == 0: - break - - self.waiting_requests_order = deque( - [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] - ) - - return scheduled_requests - - @traced - def finish_request(self, request_id: str, evict_from_cache: bool = True): - if evict_from_cache: - self.cache.free_blocks(request_id) - if request_id in self.active_requests: - del self.active_requests[request_id] - - -def get_device_and_memory(): - # Select best available device - if torch.cuda.is_available(): - device = torch.device("cuda") - total_memory = torch.cuda.get_device_properties(device).total_memory - reserved_memory = torch.cuda.memory_reserved(device) - allocated_memory = torch.cuda.memory_allocated(device) - - elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): - device = torch.device("mps") - # MPS memory reporting (PyTorch 2.0+) - total_memory = torch.mps.driver_allocated_memory() - allocated_memory = total_memory - torch.mps.recommended_max_memory() - reserved_memory = 0 # MPS does not track reserved separately - - else: - device = torch.device("cpu") - total_memory = None - reserved_memory = 0 - allocated_memory = 0 - - return device, total_memory, reserved_memory, allocated_memory - - -@traced(standalone=True) -def compute_optimal_blocks( - max_num_tokens, - block_size, - head_dim, - num_heads, - num_layers, - max_memory_percent=0.9, - num_blocks=None, - dtype=torch.float16, -): - device, total, reserved, allocated = get_device_and_memory() - available_memory = int((total - max(allocated, reserved)) * max_memory_percent) - - dtype_size = torch.tensor([], dtype=dtype).element_size() - bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers - if num_blocks is not None: - # TODO - max_possible_concurrent_requests = num_blocks * bytes_per_token - # FIXME: forgot to add the inintial prompt length in the mix.... 
- max_possible_concurrent_requests = int( - available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4) - ) - if max_possible_concurrent_requests <= 0: - logger.warning("you are trying to generate a bit too many tokens") - max_possible_concurrent_requests = 32 - max_concurrent_tokens = min(64, max_possible_concurrent_requests) - # FIXME: Optimal means uses all memory - optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64) - return optimal_num_blocks, max_concurrent_tokens - - -@dataclass -class PagedAttentionArgs: - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - cumulative_seqlens_q: torch.Tensor - cumulative_seqlens_k: torch.Tensor - max_seqlen_q: int - max_seqlen_k: int - write_index: torch.Tensor - read_index: torch.Tensor - logits_indices: torch.Tensor - block_tables: dict[str, list[int]] - cache: PagedAttentionCache - use_cache: bool = False - - -@traced -def create_document_mask(cumulative_seqlens_q, cumulative_seqlens_k): - # Number of documents - valid_docs_q = cumulative_seqlens_q[1:] > cumulative_seqlens_q[:-1] - valid_docs_k = cumulative_seqlens_k[1:] > cumulative_seqlens_k[:-1] - num_valid_docs = min(valid_docs_q.sum(), valid_docs_k.sum()) - - # Trim to valid docs - cumulative_seqlens_q = cumulative_seqlens_q[: num_valid_docs + 1] - cumulative_seqlens_k = cumulative_seqlens_k[: num_valid_docs + 1] - - total_q = cumulative_seqlens_q[-1] - total_k = cumulative_seqlens_k[-1] - - q_indices = torch.arange(total_q, device=cumulative_seqlens_q.device) - k_indices = torch.arange(total_k, device=cumulative_seqlens_k.device) - - q_doc_ids = torch.bucketize(q_indices, cumulative_seqlens_q[1:], right=True) - k_doc_ids = torch.bucketize(k_indices, cumulative_seqlens_k[1:], right=False) - doc_mask = q_doc_ids[:, None] == k_doc_ids[None, :] - # apply causal mask where no decoding (same nb of q than k) - - is_causal = ~(cumulative_seqlens_q[1:] - cumulative_seqlens_q[:-1] == 1) * cumulative_seqlens_q[1:] - apply_causal = torch.bucketize(q_indices, is_causal, right=True)[:, None] == k_doc_ids - # TODO don't apply on prefill splitting - causal_mask = torch.triu(torch.ones(total_q, total_k, device=q_doc_ids.device), diagonal=1).bool() - doc_mask.masked_fill_((apply_causal & causal_mask), False) - return doc_mask - - -# Continuous Batch Processor (Internal Logic) -@attach_tracer() -class ContinuousBatchProcessor: - def __init__( - self, - cache: PagedAttentionCache, - config: PretrainedConfig, - generation_config: GenerationConfig, - input_queue: queue.Queue, - output_queue: queue.Queue, - stop_event: threading.Event, - model_device: torch.device, - model_dtype: torch.dtype, - scheduler: Scheduler, - streaming: bool = False, - manual_eviction: bool = False, - ): - """Initialize the continuous batch processor. 
- - Args: - cache: The paged attention cache to use - generation_config: The generation configuration - input_queue: Queue for incoming requests - output_queue: Queue for outgoing results - stop_event: Event to signal processing should stop - model_device: Device for model inputs/outputs - model_dtype: Data type for model inputs/outputs - streaming: Whether to stream tokens as they're generated - """ - self.cache = cache - self.config = config - self.generation_config = generation_config - self.input_queue = input_queue - self.output_queue = output_queue - self.stop_event = stop_event - self.model_device = model_device - self.model_dtype = model_dtype - self.scheduler = scheduler - self.streaming = streaming - self.manual_eviction = manual_eviction - - self.requests_in_batch: list[RequestState] = [] - - # Set up metrics collector - self.max_batch_tokens = cache.max_batch_tokens - self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens) - - self.setup_static_tensors() - - self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path) - self.decode_stream = DecodeStream(skip_special_tokens=True) - - @traced(standalone=True) - def setup_static_tensors(self): - T = self.max_batch_tokens - max_token_budget = self.cache.num_blocks * self.cache.block_size - tensor_metadata = {"dtype": torch.int32, "device": self.model_device} - self.tensor_metadata = tensor_metadata - self.input_ids = torch.zeros((1, T), **tensor_metadata) - self.position_ids = torch.zeros((1, T), **tensor_metadata) - self.attention_mask = torch.zeros( - (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device - ) - self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata) - self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata) - self.write_index = torch.zeros((T,), **tensor_metadata) - self.read_index = torch.zeros((max_token_budget,), **tensor_metadata) - self.logits_indices = torch.full((T,), -1, **tensor_metadata) - self.max_seqlen_q = 0 - self.max_seqlen_k = 0 - self.output_ids = torch.full((1, T), -1, **tensor_metadata) - - @traced - @torch.no_grad() - def reset_static_tensors(self): - """Reset static tensors for the next batch.""" - self.input_ids.zero_() - self.position_ids.zero_() - self.attention_mask.fill_(torch.finfo(self.model_dtype).min) - self.cumulative_seqlens_q.zero_() - self.cumulative_seqlens_k.zero_() - self.write_index.fill_(-1) - self.read_index.fill_(-1) - self.logits_indices.fill_(-1) - self.max_seqlen_q = 0 - self.max_seqlen_k = 0 - self.output_ids.zero_() - - def get_model_kwargs(self) -> PagedAttentionArgs: - """Get model keyword arguments for the current batch.""" - # torch.set_printoptions(threshold=100000,linewidth=10000) - return { - "input_ids": self.input_ids, - "position_ids": self.position_ids, - "attention_mask": self.attention_mask, - "cu_seq_lens_q": self.cumulative_seqlens_q, - "cu_seq_lens_k": self.cumulative_seqlens_k, - "write_index": self.write_index, - "read_index": self.read_index, - "logits_indices": self.logits_indices, - "max_seqlen_q": self.max_seqlen_q, - "max_seqlen_k": self.max_seqlen_k, - "block_tables": self.cache._block_tables, - "cache": self.cache, - "use_cache": False, - } - - def __repr__(self): - return ( - f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})" - + self.get_model_kwargs().__repr__() - ) - - @traced - def _get_new_requests(self): - 
"""Pull new requests from the input queue and add to waiting list.""" - while not self.input_queue.empty(): - try: - state = self.input_queue.get_nowait() - if state is None: # Sentinel value - continue - self.scheduler.add_waiting_request(state) - - except queue.Empty: - break - except Exception as e: - logger.error(f"Error processing new request: {e}", exc_info=True) - state: RequestState = locals().get("state") - if state is not None: - self._handle_request_error(e, state) - - @traced - def _handle_request_error(self, error, state: RequestState): - """Handle general request processing error.""" - state.status = RequestStatus.FAILED - state.error = str(error) - - # Include any generated tokens if this is an active request - if isinstance(state.request_id, str): - state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id) - else: - state.static_outputs = [] - - self.metrics.record_request_completion(state.created_time, state.request_id) - self.output_queue.put(state.to_generation_output()) - - @traced - def prepare_next_batch(self): - """Prepare tensors and metadata for the next model forward pass.""" - # Get new requests from the queue - self._get_new_requests() - if not self.scheduler.has_pending_requests(): - return None - - self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests)) - - self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens) - if not self.requests_in_batch: - return None - - # Get the request objects for this batch - self.reset_static_tensors() - position_ids = [] - input_ids = [] - read_index = [] - write_index = [] - cumulative_seqlens_q = [0] - cumulative_seqlens_k = [0] - logits_indices = [] - self.metrics.record_batch_metrics(self.requests_in_batch) - - for state in self.requests_in_batch: - next_input_ids = state.prompt_ids - input_ids.extend(next_input_ids) - past_length = state.position_offset - query_length = len(next_input_ids) - key_length = query_length + past_length - cache_index = list(range(key_length)) - - positions_to_add = cache_index[past_length:] - read_indices = self.cache._get_physical_indices(state, cache_index) - write_indices = read_indices[-query_length:] - - position_ids.extend(positions_to_add) - read_index.extend(read_indices) - write_index.extend(write_indices) - cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length) - cumulative_seqlens_k.append(cumulative_seqlens_k[-1] + key_length) - if len(state.remaining_prompt_ids) == 0: - logits_indices.append(cumulative_seqlens_q[-1] - 1) - self.max_seqlen_q = max(self.max_seqlen_q, query_length) - self.max_seqlen_k = max(self.max_seqlen_k, key_length) - state.position_offset += query_length - - logger.info( - f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. 
cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}" - ) - self._build_tensors( - input_ids, - position_ids, - read_index, - write_index, - cumulative_seqlens_q, - cumulative_seqlens_k, - logits_indices, - ) - - self.metrics.record_kv_cache_memory_metrics(self.cache) - - @traced - def _build_tensors( - self, - input_ids, - position_ids, - read_index, - write_index, - cumulative_seqlens_q, - cumulative_seqlens_k, - logits_indices, - ): - to_tensor = partial(torch.tensor, **self.tensor_metadata) - self.input_ids[:, : len(input_ids)] = to_tensor(input_ids) - self.position_ids[:, : len(position_ids)] = to_tensor(position_ids) - self.write_index[: len(write_index)] = to_tensor(write_index) - self.read_index[: len(read_index)] = to_tensor(read_index) - self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q) - self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k) - self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices) - min_value = torch.finfo(self.model_dtype).min - if self.config._attn_implementation != "paged_attention": # we set `is_causal` to True in paged call` - for i in range(len(cumulative_seqlens_q) - 1): - if ( - cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] - < cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i] - and cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] >= 1 - ): - diagonal = ( - cumulative_seqlens_k[i + 1] - (cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]) + 1 - ) - diagonal = diagonal - cumulative_seqlens_k[i] - else: - diagonal = 1 - query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1]) - key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1]) - - mask = torch.triu( - torch.full( - self.attention_mask[..., query_range, key_range].shape, - min_value, - dtype=self.model_dtype, - device=self.model_device, - ), - diagonal=diagonal, - ) - self.attention_mask[..., query_range, key_range] = mask - - @traced - def _sync(self): - if self.output_ids is not None: - try: - out = self.output_ids.tolist()[0] # should be the only synch we do - except Exception: - out = [0, 1] - else: - out = [0, 0] - return out - - @traced - def _maybe_send_output(self, state: RequestState, token: int): - """Send output to the queue based on streaming mode and request state.""" - if self.streaming: - state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1]) - self.output_queue.put(state.to_generation_output()) - elif state.status == RequestStatus.FINISHED: - self.output_queue.put(state.to_generation_output()) - - @traced - def update_batch(self): - """Update request states based on generated tokens.""" - out_tokens = self._sync() - finished_request_ids = [] - for i, state in enumerate(self.requests_in_batch): - req_id = state.request_id - if len(state.remaining_prompt_ids) == 0: - self.metrics.record_ttft_metric(state.created_time, state.request_id) - state.status = RequestStatus.DECODING - token = out_tokens[self.logits_indices[i]] - state.prompt_ids = [token] - if state.update_with_token(token): - self.metrics.record_request_completion(state.created_time, state.request_id) - self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction)) - finished_request_ids.append(req_id) - self._maybe_send_output(state, token) - elif state.status == RequestStatus.PREFILLING_SPLIT: - state.status = RequestStatus.SPLIT_PENDING_REMAINDER - if self.cache.get_num_free_blocks() == 0: - 
raise ValueError("No more free blocks") - - @traced - def has_pending_requests(self) -> bool: - """Check if there are any active or waiting requests.""" - return self.scheduler.has_pending_requests() - - @traced - def handle_batch_error(self, error): - """Handle errors during batch processing.""" - failed_reqs = self.requests_in_batch - for req in failed_reqs: - self._handle_request_error(error, req) - self.scheduler.finish_request(req.request_id) - - @traced - def fail_all_requests(self, error): - """Fail all active requests with the given error. - - Args: - error: The error to report in the failure message - """ - - requests = list(self.scheduler.active_requests.values()) - for state in requests: - self._handle_request_error(error, state) - self.scheduler.finish_request(state.request_id) - - # Also fail any requests in the waiting queue - for req_id in list(self.scheduler.waiting_requests.keys()): - state = self.scheduler.waiting_requests.pop(req_id) - self._handle_request_error(error, state) - - # Clear the ordering queue - self.scheduler.waiting_requests_order.clear() - - -SCHEDULER_MAPPING = { - "fifo": FIFOScheduler, - "prefill_first": PrefillFirstScheduler, -} - - -# Manager Class (User Interface) -@attach_tracer() -class ContinuousBatchingManager: - """Manager for handling continuous batching of generation requests. - - This class provides the user interface for submitting generation requests, - retrieving results, and managing the background generation thread. - """ - - def __init__( - self, - model, - generation_config: GenerationConfig, - manual_eviction: bool = False, - max_queue_size=0, - streaming: bool = True, - ): - """Initialize the continuous batching manager. - - Args: - model: The language model for generation - generation_config: Configuration for generation parameters - max_queue_size: Maximum size of the request queue (0 = unlimited) - streaming: Whether to stream tokens as they are generated - """ - self.model = model.eval() - generation_config = model.generation_config if generation_config is None else generation_config - self.generation_config = generation_config - self.input_queue = queue.Queue(maxsize=max_queue_size) - self.output_queue = queue.Queue() - self.stop_event = threading.Event() - self.streaming = streaming - self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) - self._generation_thread = None - self._request_counter = 0 - self._request_lock = threading.Lock() - self.model.generation_config.top_p = None - self.do_sample = getattr(generation_config, "do_sample", True) - self.logit_processor = self.model._get_logits_processor(generation_config) - self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) - self.profile = getattr(generation_config, "profile", False) - self.manual_eviction = manual_eviction - self.batch_processor: Optional[ContinuousBatchProcessor] = None - self.decode_stream = DecodeStream(skip_special_tokens=True) - - @traced - def start(self): - """Start the background generation thread.""" - if self._generation_thread is not None and self._generation_thread.is_alive(): - logger.warning("Manager thread is already running.") - return - - self._result_queue = queue.Queue() - self._generation_thread = threading.Thread(target=self._run_generation_loop) - self._generation_thread.start() - logger.info("Continuous batching manager started.") - - def is_running(self): - """Check if the background generation thread is running.""" - return self._generation_thread is not None and 
self._generation_thread.is_alive() - - def stop(self, block: bool = False, timeout: Optional[float] = None): - """Signal the background thread to stop. - - Args: - block: Whether to wait for the thread to stop - timeout: Maximum time to wait for the thread to stop - """ - if self._generation_thread is None: - logger.warning("Manager not started.") - return - - if not self.stop_event.is_set(): - self.stop_event.set() - logger.info("Stopping continuous batching manager...") - - if block: - self.join(timeout) - - def join(self, timeout: Optional[float] = None): - """Wait for the background thread to finish. - - Args: - timeout: Maximum time to wait for the thread to stop - """ - if self._generation_thread is not None: - self._generation_thread.join(timeout=timeout) - if self._generation_thread.is_alive(): - logger.warning("Generation thread did not exit after join timeout.") - else: - logger.info("Continuous Batching Manager stopped.") - self._generation_thread = None - - def add_request( - self, input_ids: list[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None - ) -> str: - """Add a new generation request to the queue. - - Args: - input_ids: Input token IDs to use as prompt - request_id: Optional custom request ID (auto-generated if None) - **kwargs: Additional generation parameters - - Returns: - str: The request ID - """ - if request_id is None: - with self._request_lock: - request_id = f"req_{self._request_counter}" - self._request_counter += 1 - - max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens - - state = RequestState( - request_id=request_id, - prompt_ids=list(input_ids), - full_prompt_ids=list(input_ids), - max_new_tokens=max_new_tokens, - eos_token_id=self.generation_config.eos_token_id, - ) - - # Use block=True with timeout to handle backpressure if queue is full - self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg? - logger.debug(f"Added request {request_id} to queue.") - return request_id - - def add_requests(self, inputs: list[list[int]], **kwargs): - for i, input_ids in enumerate(inputs): - # Assign a predictable request ID for ordering results later - req_id = f"batch_req_{i}" - self.add_request(input_ids, request_id=req_id, **kwargs) - - def get_result(self, timeout=None) -> Optional[GenerationOutput]: - """Retrieve one result from the output queue. 
- - Args: - timeout: Maximum time to wait for a result - - Returns: - Optional[Dict]: The result data or None if timeout - """ - if self._generation_thread is None and self.output_queue.empty(): - return None - try: - result = self.output_queue.get(block=True, timeout=timeout) - logger.debug(f"Retrieved result for request {result.request_id}") - return result - except queue.Empty: - return None - - def __iter__(self): - """Iterate over results as they become available.""" - while ( - self._generation_thread is not None and self._generation_thread.is_alive() or not self.output_queue.empty() - ): - result = self.get_result(timeout=0.1) # allow the model to run for 10 seconds - if result is not None: - yield result - - @traced - def warmup(self, batch_processor): - stream = torch.cuda.Stream(device=self.model.device) - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - # Warmup the model with a dummy forward pass - self._generation_step(batch_processor) - torch.cuda.current_stream().wait_stream(stream) - - self.graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.graph, stream=stream): - self._generation_step(batch_processor) - - @traced - # @torch.compile - def _generation_step(self, batch_processor: ContinuousBatchProcessor): - """Perform a single generation step. This is cuda graphed""" - batch_data = batch_processor.get_model_kwargs() - with torch.no_grad(): - logits = self._model_forward(batch_data) - if self.log_prob_generation: - batch_processor.output_probs.copy_(logits) # TODO - probs = self._process_logit(batch_data, logits) - self._sample(batch_processor, probs) - - @traced(span_name="model_forward") - def _model_forward(self, batch_data): - return self.model(**batch_data).logits - - @traced(span_name="logit_processing") - def _process_logit(self, batch_data, logits): - # Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner! - if hasattr(self.logit_processor, "set_continuous_batching_context"): - self.logit_processor.set_continuous_batching_context( - batch_data["logits_indices"], batch_data["cu_seq_lens_q"] - ) - return self.logit_processor(batch_data["input_ids"], logits) - - @traced(span_name="sampling") - def _sample(self, batch_processor: ContinuousBatchProcessor, probs): - if self.do_sample: # sample - probs = nn.functional.softmax(probs, dim=-1) - next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - batch_processor.output_ids.copy_(next_tokens) - - def _run_generation_loop(self): - """Main processing loop running in the background thread.""" - batch_processor = None - try: - paged_attention_cache = PagedAttentionCache( - self.model.config, - self.generation_config, - self.model.device, - self.model.dtype, - num_requests=len(self.input_queue.queue), - tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting - ) - - scheduler = None - if hasattr(self.generation_config, "scheduler"): - scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler) - if scheduler is None: - logger.warning(f"Scheduler '{scheduler}' not found. 
Defaulting to FIFO.") - scheduler = FIFOScheduler - else: - # Default to fifo - scheduler = FIFOScheduler - - batch_processor = ContinuousBatchProcessor( - paged_attention_cache, - self.model.config, - self.generation_config, - self.input_queue, - self.output_queue, - self.stop_event, - self.model.device, - self.model.dtype, - scheduler(paged_attention_cache, self.manual_eviction), - self.streaming, - self.manual_eviction, - ) - self.batch_processor = batch_processor - is_first = True - while (not self.stop_event.is_set()) or batch_processor.has_pending_requests(): - self._inner_generation_loop(batch_processor, is_first) - if is_first: - is_first = False - - except Exception as e: - logger.error(f"Error in generation loop: {e}", exc_info=True) - self._handle_critical_error(e, batch_processor) - finally: - logger.info("Generation loop finished.") - - @traced(span_name="generation_loop") - def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor, is_first: bool = False): - if torch.cuda.is_available(): - torch.cuda.synchronize() - batch_processor.prepare_next_batch() - device, total, reserved, allocated = get_device_and_memory() - logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") - if torch.cuda.is_available() and self.use_cuda_graph: - if is_first: - self.warmup(batch_processor) - elif hasattr(self, "graph"): - try: - self._graph_replay() - except Exception as e: - logger.error(f"Model forward pass failed: {e}", exc_info=True) - batch_processor.handle_batch_error(e) - return - else: - self._generation_step(batch_processor) - else: - self._generation_step(batch_processor) - if torch.cuda.is_available(): - torch.cuda.synchronize() - batch_processor.update_batch() - - @traced(span_name="graph_replay") - def _graph_replay(self): - self.graph.replay() - - @traced - def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]): - """Handle critical errors that terminate the generation loop.""" - # Signal stop - self.stop_event.set() - - # Fail pending requests in input queue - try: - while True: - req_data = self.input_queue.get_nowait() - if batch_processor is not None: - batch_processor._handle_request_error(error, req_data) - except queue.Empty: - pass - - # Fail active requests - if batch_processor is not None: - batch_processor.fail_all_requests(error) - - @traced - def evict_request_from_cache(self, request_id: str): - """Evict a request from the cache. It is assumed that the request is already finished.""" - if not self.manual_eviction: - raise RuntimeError("Manual eviction is not enabled for this manager.") - if self.batch_processor is not None: - self.batch_processor.scheduler.finish_request(request_id) - - -class ContinuousMixin: - """Mixin class for models to add continuous batching capabilities.""" - - def init_continuous_batching( - self, - generation_config: Optional[GenerationConfig] = None, - manual_eviction: bool = False, - max_queue_size: int = 0, - streaming: bool = False, - ) -> ContinuousBatchingManager: - """Initialize a manager for continuous batching inference. - - Args: - generation_config: Custom generation configuration - max_queue_size: Maximum size of the input request queue - streaming: Whether to stream tokens as they are generated - - Returns: - `ContinuousBatchingManager`: The manager instance to add requests and retrieve results. 
- """ - if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"): - raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.") - - gen_config = generation_config if generation_config is not None else self.generation_config - if gen_config is None: - raise ValueError("A GenerationConfig must be provided or set in the model.") - - if gen_config.eos_token_id is None: - logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).") - gen_config.eos_token_id = -1 - - # Create and return the manager - return ContinuousBatchingManager( - model=self, - generation_config=gen_config, - manual_eviction=manual_eviction, - max_queue_size=max_queue_size, - streaming=streaming, - ) - - @traced - @torch.inference_mode() - def generate_batch( - self, - inputs: list[list[int]], - generation_config: Optional[GenerationConfig] = None, - progress_bar: bool = True, - **kwargs, - ) -> list[list[int]]: - """Generate sequences for a batch of prompts using continuous batching. - - Args: - inputs: List of input token sequences (prompts) - generation_config: Optional generation configuration - **kwargs: Additional generation parameters - - Returns: - `list[list[int]]`: A list containing the generated sequences (including prompt tokens - if not handled otherwise) for each input prompt, in the same order. - Returns an empty list `[]` for requests that failed. - """ - if not inputs: - return [] - - # Initialize manager with the batch inputs - manager = self.init_continuous_batching(generation_config=generation_config) - manager.start() - results = {} - num_requests = len(inputs) - try: - from tqdm.contrib.logging import logging_redirect_tqdm - - with logging_redirect_tqdm([logger]): - with tqdm( - total=num_requests, - disable=(not progress_bar), - desc=f"Solving {num_requests} requests", - unit="request", - ) as pbar: - manager.add_requests(inputs, **kwargs) - finished_count = 0 - while finished_count < num_requests: - result = manager.get_result(timeout=1) - if result: - req_id = result.request_id - if result.status == RequestStatus.FINISHED: - results[req_id] = result - finished_count += 1 - pbar.update(1) - logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens)) - else: - if not manager.is_running(): - logger.error("Generation thread terminated unexpectedly.") - break - - except Exception as e: - logger.error(f"Error during batch generation: {e}", exc_info=True) - finally: - manager.stop(block=True, timeout=5.0) - return results diff --git a/src/transformers/generation/continuous_batching/__init__.py b/src/transformers/generation/continuous_batching/__init__.py new file mode 100644 index 000000000000..11d15b6468e2 --- /dev/null +++ b/src/transformers/generation/continuous_batching/__init__.py @@ -0,0 +1,20 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .cache import PagedAttentionCache +from .classes import RequestState, RequestStatus +from .continuous_api import ContinuousBatchingManager, ContinuousMixin + + +__all__ = ["PagedAttentionCache", "RequestState", "RequestStatus", "ContinuousMixin", "ContinuousBatchingManager"] diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py new file mode 100644 index 000000000000..dfc10859b41e --- /dev/null +++ b/src/transformers/generation/continuous_batching/cache.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque +from math import floor, sqrt +from typing import Optional, Union + +import torch + +from ...configuration_utils import PretrainedConfig +from ...generation.configuration_utils import GenerationConfig +from ...utils.metrics import attach_tracer, traced +from .classes import RequestState, get_device_and_memory_breakdown, logger + + +@attach_tracer() +class PagedAttentionCache: + def __init__( + self, + config: PretrainedConfig, + generation_config: GenerationConfig, + device: torch.device, + dtype: torch.dtype = torch.float16, + num_requests: int = 100, + layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + tp_size: Optional[int] = None, + ) -> None: + """Initialize a paged attention cache for efficient memory usage. + + Args: + config: Model configuration + generation_config: Generation configuration containing cache parameters + device: Device for the cache tensors + dtype: Data type for the cache tensors + layer_device_map: Optional mapping of layer indices to devices + initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size + """ + self.dtype = dtype + self.device = device + + # Extract model dimensions + kv_heads = getattr(config, "num_key_value_heads", None) + self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads + head_dim = getattr(config, "head_dim", None) + self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads + + self.num_hidden_layers = config.num_hidden_layers + self.block_size = getattr(generation_config, "block_size", 32) + + # Handle TP + if tp_size is not None and tp_size > 1: + if self.num_key_value_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + # self.num_key_value_heads //= tp_size # TODO: why is this commented out? 
+ + # Infer number of blocks and max batch tokens + memory_handler = PagedAttentionMemoryHandler( + block_size=self.block_size, + head_dim=self.head_dim, + num_heads=self.num_key_value_heads, + num_layers=self.num_hidden_layers, + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + ) + num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( + num_blocks=getattr(generation_config, "num_blocks", None), + max_batch_tokens=getattr(generation_config, "max_batch_tokens", None), + max_memory_percent=getattr(generation_config, "max_memory", 0.9), + cache_dtype=self.dtype, + ) + + # Add the infered attributes to the class + self.num_blocks = num_blocks + self.max_batch_tokens = max_batch_tokens + logger.warning(f"PagedAttentionCache initialized with {self.num_blocks = } and {self.max_batch_tokens = } ") + + # Initialize the cache + self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + for idx in range(config.num_hidden_layers): + layer_device = layer_device_map[idx] if layer_device_map is not None else device + new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + # Block management data structures + self._free_blocks = deque(range(num_blocks)) + self._block_tables: dict[str, list[int]] = {} + + @traced + def allocate_blocks(self, n_blocks: int, request_id: str) -> list[int]: + """Allocates n_blocks for a given request_id.""" + if len(self._free_blocks) < n_blocks: + return False + + allocated = [] + for _ in range(n_blocks): + allocated.append(self._free_blocks.popleft()) + + if request_id not in self._block_tables: + self._block_tables[request_id] = [] + self._block_tables[request_id].extend(allocated) + return allocated + + @traced + def free_blocks(self, request_id: str) -> None: + """Frees all blocks associated with a request_id.""" + if request_id in self._block_tables: + blocks_to_free = self._block_tables.pop(request_id) + self._free_blocks.extend(blocks_to_free) + else: + logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}") + + def get_num_free_blocks(self) -> int: + """Returns the number of free blocks available.""" + return len(self._free_blocks) + + def get_block_table(self, request_id: str) -> list[int]: + """Returns the block table for a request.""" + return self._block_tables.get(request_id, []) + + @traced + def _get_physical_indices(self, state: RequestState, logical_indices: list[int]) -> list[int]: + """ + Maps logical sequence indices to physical cache indices using the block table, using PyTorch. + + Args: + request_id: The request ID. + logical_indices: A list of logical indices. + + Returns: + A list of physical indices. + + Raises: + ValueError: If no block table is found for the request ID. + IndexError: If a logical index maps to a block index that is out of bounds. 
+ """ + request_id = state.request_id + block_table = self._block_tables.get(request_id) + if not block_table: + raise ValueError(f"No block table found for request {request_id}") + + block_size = self.block_size + physical_indices = [] + + for idx in logical_indices: + block_idx = idx // block_size + block_offset = idx % block_size + + if block_idx >= len(block_table): + raise IndexError( + f"Logical index {idx} maps to block index {block_idx} which is out of bounds " + f"for request {request_id}" + ) + + physical_block_num = block_table[block_idx] + physical_index = physical_block_num * block_size + block_offset + physical_indices.append(physical_index) + + return physical_indices + + @traced + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + read_index, + write_index, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Reshape cache for easier indexing + total_slots = self.num_blocks * self.block_size + k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + k_cache_flat[:, write_index, :] = key_states[0] + v_cache_flat[:, write_index, :] = value_states[0] + return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :] + + +class PagedAttentionMemoryHandler: + _activation_dtype = torch.bfloat16 + _activation_safety_factor = 2 + _input_dtype = torch.int32 + _upper_bound_max_batch_tokens = 256 + _upper_bound_num_blocks = 4096 + + def __init__( + self, + block_size: int, + head_dim: int, + num_heads: int, + num_layers: int, + hidden_size: int, + vocab_size: int, + ) -> None: + self.block_size = block_size + self.head_dim = head_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + @staticmethod + def get_available_memory(max_memory_percent: float = 1.0) -> int: + _, total, reserved, allocated = get_device_and_memory_breakdown() + available_memory = total - max(allocated, reserved) + available_memory = int(available_memory * max_memory_percent) + return available_memory + + def infer_num_blocks_and_max_batch_tokens( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int]: + """ + The memory footprint depends on the cache size C and the max batch tokens M in the following way: + Mem = Mem(cache) + Mem(activation) + Mem(static_tensors) + where: + Mem(cache) = 2 * num_heads * head_dim * num_layers * cache_dtype.itemsize * C + Mem(activation) = M * (hidden_size + vocab_size) * activation_dtype.itemsize + Mem(static_tensors) ~= 8M * input_dtype.itemsize + M * C * activation_dtype.itemsize + + Depending on if C or M is given, we use different methods to infer the values (C = num_blocks * block_size) and + since block_size is fixed, num_blocks is the true variable to find. 
+ """ + # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial + if num_blocks is None and max_batch_tokens is None: + num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens( + max_memory_percent, cache_dtype + ) + # If only num_blocks is provided, we infer the max_batch_tokens + elif num_blocks is not None and max_batch_tokens is None: + max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype) + # If only max_batch_tokens is provided, we infer the num_blocks + elif max_batch_tokens is not None and num_blocks is None: + num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype) + + # We check if the memory footprint is too large in all cases + available_memory = self.get_available_memory(max_memory_percent) + memory_footprint = self.compute_memory_footprint( + max_batch_tokens=max_batch_tokens, + num_blocks=num_blocks, + cache_dtype=cache_dtype, + ) + if sum(memory_footprint) > available_memory: + raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}") + return num_blocks, max_batch_tokens + + def compute_num_blocks_and_max_batch_tokens( + self, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + m: float = 0.01, + ) -> tuple[int, int]: + """ + If neither M nor C is given, we assume M = m*C so we have to solve a second-order polynomial in C: + Mem = C * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + + C * m * (hidden_size + vocab_size) * activation_dtype.itemsize + + C * m * 8 * input_dtype.itemsize + C^2 * m * activation_dtype.itemsize + + We solve for C and then M = m*C. + """ + cache_memory = self.get_available_memory(max_memory_percent) + logger.info(f"Cache memory: {cache_memory}") + + # Compute memory footprints + mem_per_activation_token = m * self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) + mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + mem_per_input_token = 8 * m * self._input_dtype.itemsize + logger.info(f"Memory per activation token: {mem_per_activation_token}") + logger.info(f"Memory per cache token: {mem_per_cache_token}") + logger.info(f"Memory per input token: {mem_per_input_token}") + + # Compute second-degree polynomial coefficients + a = m * self._activation_dtype.itemsize + b = mem_per_input_token + mem_per_cache_token + mem_per_activation_token + c = -cache_memory + + # Compute discriminant and greatest solution + discriminant = b**2 - 4 * a * c + if discriminant < 0: + raise ValueError(f"Discriminant is negative: {discriminant = }") + greatest_solution = (-b + sqrt(discriminant)) / (2 * a) + if greatest_solution < 0: + raise ValueError(f"Greatest solution is negative: {greatest_solution = }") + + # Infer number of blocks and max batch tokens + num_blocks = int(greatest_solution) // self.block_size + if num_blocks > self._upper_bound_num_blocks: + logger.warning(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }") + num_blocks = self._upper_bound_num_blocks + max_batch_tokens = int(greatest_solution * m) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + logger.warning(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }") + max_batch_tokens = self._upper_bound_max_batch_tokens + return num_blocks, max_batch_tokens + + def compute_max_batch_tokens( + self, + num_blocks: int, + 
max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + """ + If C is given, we have a formula for M: + num = (Mem - C * 2 * num_heads * head_dim * num_layers * cache_dtype.itemsize) + denum = (8 * input_dtype.itemsize + C * activation_dtype.itemsize + (hidden_size + vocab_size) * activation_dtype.itemsize) + M = num / denum + """ + cache_memory = self.get_available_memory(max_memory_percent) + cache_size = num_blocks * self.block_size + # Compute numerator + num = cache_memory + num -= cache_size * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + # Compute denominator + denum = 8 * self._input_dtype.itemsize + cache_size * self._activation_dtype.itemsize + denum += (self.hidden_size + self.vocab_size) * self._activation_dtype.itemsize + # Compute max batch tokens and return + return int(num / denum) + + def compute_num_blocks( + self, + max_batch_tokens: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + """ + If M is given, we have a formula for C: + num = Mem - M * (hidden_size + vocab_size) * activation_dtype.itemsize - 8 * M * input_dtype.itemsize + denum = 2 * num_heads * head_dim * num_layers * cache_dtype.itemsize + M * activation_dtype.itemsize + C = num / denum + """ + cache_memory = self.get_available_memory(max_memory_percent) + # Compute numerator + num = cache_memory + num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * max_batch_tokens + num -= 8 * max_batch_tokens * self._input_dtype.itemsize + # Compute denominator + denum = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + denum += max_batch_tokens * self._activation_dtype.itemsize + # Compute cache size and return number of blocks + cache_size = int(num / denum) + return floor(cache_size / self.block_size) + + def compute_memory_footprint( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int, int]: + # Compute activation memory footprint + activation_memory_footprint = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) + activation_memory_footprint *= max_batch_tokens + # Compute cache memory footprint if num_blocks is provided + if num_blocks is not None: + cache_size = num_blocks * self.block_size + bytes_per_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + cache_memory_footprint = cache_size * bytes_per_token + else: + cache_memory_footprint = -1 + # Compute static tensors memory footprint if num_blocks and max_batch_tokens is provided + if num_blocks is not None and max_batch_tokens is not None: + static_memory_footprint = sum( + [ + 3 * max_batch_tokens * self._input_dtype.itemsize, # input_ids, position_ids, output_ids + max_batch_tokens * cache_size * self._activation_dtype.itemsize, # attention_mask + 2 * max_batch_tokens * self._input_dtype.itemsize, # cumulative_seqlens_qk (we remove the +1 to M) + 3 * max_batch_tokens * self._input_dtype.itemsize, # write_index, read_index, logits_indices + ] + ) + else: + static_memory_footprint = -1 + return activation_memory_footprint, cache_memory_footprint, static_memory_footprint diff --git a/src/transformers/generation/continuous_batching/classes.py b/src/transformers/generation/continuous_batching/classes.py new file mode 100644 index 000000000000..f2c3a9eda455 --- /dev/null +++ b/src/transformers/generation/continuous_batching/classes.py @@ -0,0 
+1,210 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +import torch + +from ...utils.logging import logging +from ...utils.metrics import traced + + +# We centralize the logger here to coordinate between logging and progress bar +logger = logging.getLogger("ContinuousBatchingLogger") +logger.setLevel(logging.INFO) + + +@staticmethod +def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: + if torch.cuda.is_available(): + device = torch.device("cuda") + torch.cuda.empty_cache() + torch.cuda.synchronize() + total_memory = torch.cuda.get_device_properties(device).total_memory + reserved_memory = torch.cuda.memory_reserved(device) + allocated_memory = torch.cuda.memory_allocated(device) + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + # MPS memory reporting (PyTorch 2.0+) + total_memory = torch.mps.driver_allocated_memory() + allocated_memory = total_memory - torch.mps.recommended_max_memory() + reserved_memory = 0 # MPS does not track reserved separately + else: + device = torch.device("cpu") + total_memory = None + reserved_memory = 0 + allocated_memory = 0 + return device, total_memory, reserved_memory, allocated_memory + + +class RequestStatus(Enum): + """Status of a generation request through its lifecycle.""" + + PENDING = "pending" + PREFILLING = "prefilling" + PREFILLING_SPLIT = "prefilling_split" + SPLIT_PENDING_REMAINDER = "split_pending_remainder" + DECODING = "decoding" + FINISHED = "finished" + FAILED = "failed" + + +@dataclass +class GenerationOutput: + """Tracks the output of a generation request. + + Attributes: + request_id (str): The ID of the generation request. + prompt_ids (list[int]): The IDs of the prompt tokens. + generated_tokens (list[int]): The generated tokens. + logprobs (list[float]): The log probabilities of the generated tokens. + error (Optional[str]): Any error message associated with the request. When None, the request was successful. + status (RequestStatus): The status of the request. + created_time (float): The time the request was created. + next_token (Optional[int]): The next token to be generated. + """ + + request_id: str + prompt_ids: list[int] = field(default_factory=list) + generated_tokens: list[int] = field(default_factory=list) + logprobs: list[float] = field(default_factory=list) + error: Optional[str] = None + status: RequestStatus = RequestStatus.PENDING + created_time: float = field(default_factory=time.time) + next_token: Optional[int] = field(default_factory=int) + + +@dataclass +class RequestState: + """Tracks the state of a generation request through its lifecycle. + + Attributes: + request_id (str): The ID of the generation request. + full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. + prompt_ids (list[int] | None): The tokens IDs currently being processed. 
+ remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). + static_outputs (list[int]): The generated tokens. + allocated_blocks (list[int]): The identifiers of the allocated blocks to the request. + position_offset (int): The current position in the sequence for position_ids. + status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT, + SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED + max_new_tokens (int): The maximum number of new tokens to generate. + eos_token_id (int): The ID of the end-of-sequence token. + created_time (float): The time the request was created. + error (Optional[str]): Any error message associated with the request. When None, has had no error yet. + next_token (Optional[str]): The next token to be generated. + """ + + # Required fields + request_id: str + full_prompt_ids: Optional[list[int]] = None # Full initial prompt + prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated) + remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process + static_outputs: list[int] = field(default_factory=list) # Generated tokens + allocated_blocks: list[int] = field(default_factory=list) # Block IDs allocated to the request + position_offset: int = 0 # Current position in the sequence for position_ids + _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property + max_new_tokens: int = 20 # Maximum number of new tokens to generate + eos_token_id: int = -1 # ID of the end-of-sequence token + created_time: float = field(default_factory=time.time) # Time the request was created + error: Optional[str] = None # Error message if the request failed + next_token: Optional[str] = None # Next token to be generated + lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished) + + @property + def status(self) -> RequestStatus: + return self._status + + @status.setter + def status(self, value: RequestStatus): + if self._status == RequestStatus.PENDING: + self.lifespan = (time.time(), -1) + elif value == RequestStatus.FINISHED: + self.lifespan = (self.lifespan[0], time.time()) + self.log_end_of_request() + self._status = value + + def log_end_of_request(self): + prefill_len = len(self.full_prompt_ids) + decode_len = self.generated_len() + start_time = self.lifespan[0] - self.created_time + end_time = self.lifespan[1] - self.created_time + logger.info( + f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }" + ) + + def current_len(self) -> int: + """Get the current length of the sequence (prompt + generated tokens).""" + return self.position_offset + + def generated_len(self) -> int: + """Get the number of tokens generated so far.""" + return len(self.static_outputs) + + # TODO: this logic seems one token off, check it out + @traced + def update_with_token(self, token_id: int) -> bool: + """Update the request with a newly generated token and check for completion. 
+ + Args: + token_id: The token ID to add to the output sequence + + Returns: + bool: True if the request is now complete, False otherwise + """ + # Only update if we're in decoding state + if self.status != RequestStatus.DECODING: + return False + + is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 + is_max_len = self.generated_len() >= self.max_new_tokens + + # Only add the token if we're not finishing due to max length + # (EOS tokens should still be added to the output) + if not (is_max_len and not is_eos): + self.static_outputs.extend([token_id]) + + if is_eos or is_max_len: + self.status = RequestStatus.FINISHED + return True + return False + + def __repr__(self): + msg = [ + f"request_id={self.request_id}", + f"status={self._status}", + f"out_tokens={self.generated_len()}", + f"query_length={len(self.prompt_ids)}", + f"remaining_tokens={len(self.remaining_prompt_ids)}", + f"kv_length={self.position_offset}", + f"full_prompt_lenght={len(self.full_prompt_ids)}", + f"allocated_blocks={self.allocated_blocks}", + f"generated_tokens={self.static_outputs}", + ] + return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)" + + def to_generation_output(self): + """Convert the request state to a GenerationOutput object.""" + return GenerationOutput( + request_id=self.request_id, + prompt_ids=self.full_prompt_ids, + status=self.status, + generated_tokens=self.static_outputs, + logprobs=[], + error=self.error, + next_token=self.next_token, + ) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py new file mode 100644 index 000000000000..4b6775141362 --- /dev/null +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -0,0 +1,842 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
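Before moving on to the batching machinery in `continuous_api.py`, here is a short sketch of the `RequestState` lifecycle defined above. The token ids are illustrative; the status transitions and the `update_with_token` behaviour follow the code in `classes.py`.

```py
from transformers.generation.continuous_batching import RequestState, RequestStatus

state = RequestState(
    request_id="req_0",
    prompt_ids=[101, 2023, 2003],
    full_prompt_ids=[101, 2023, 2003],
    max_new_tokens=3,
    eos_token_id=2,
)
state.status = RequestStatus.PREFILLING  # scheduler picked the request up
state.status = RequestStatus.DECODING    # prefill done, tokens are now being sampled

assert state.update_with_token(7) is False  # token appended, request still running
assert state.update_with_token(2) is True   # EOS reached, status moves to FINISHED

output = state.to_generation_output()       # the GenerationOutput put on the output queue
print(output.request_id, output.generated_tokens)  # req_0 [7, 2]
```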
+import queue +import threading +from dataclasses import dataclass +from functools import partial +from typing import Optional + +import torch +from tokenizers.decoders import DecodeStream +from torch import nn +from tqdm import tqdm + +from ...configuration_utils import PretrainedConfig +from ...generation.configuration_utils import GenerationConfig +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils.logging import logging +from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced +from .cache import PagedAttentionCache +from .classes import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger +from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler + + +@dataclass +class PagedAttentionArgs: + input_ids: torch.Tensor + attention_mask: Optional[torch.Tensor] + position_ids: torch.Tensor + cumulative_seqlens_q: torch.Tensor + cumulative_seqlens_k: torch.Tensor + max_seqlen_q: int + max_seqlen_k: int + write_index: torch.Tensor + read_index: torch.Tensor + logits_indices: torch.Tensor + block_tables: dict[str, list[int]] + cache: PagedAttentionCache + use_cache: bool = False + + +# Continuous Batch Processor (Internal Logic) +@attach_tracer() +class ContinuousBatchProcessor: + def __init__( + self, + cache: PagedAttentionCache, + config: PretrainedConfig, + generation_config: GenerationConfig, + input_queue: queue.Queue, + output_queue: queue.Queue, + stop_event: threading.Event, + model_device: torch.device, + model_dtype: torch.dtype, + scheduler: Scheduler, + streaming: bool = False, + manual_eviction: bool = False, + slice_inputs: bool = True, # TODO: remove this once parity is ensured + ): + """Initialize the continuous batch processor. + + Args: + cache: The paged attention cache to use + generation_config: The generation configuration + input_queue: Queue for incoming requests + output_queue: Queue for outgoing results + stop_event: Event to signal processing should stop + model_device: Device for model inputs/outputs + model_dtype: Data type for model inputs/outputs + streaming: Whether to stream tokens as they're generated + """ + self.cache = cache + self.config = config + self.generation_config = generation_config + self.input_queue = input_queue + self.output_queue = output_queue + self.stop_event = stop_event + self.model_device = model_device + self.model_dtype = model_dtype + self.scheduler = scheduler + self.streaming = streaming + self.manual_eviction = manual_eviction + self.slice_inputs = slice_inputs + + self.requests_in_batch: list[RequestState] = [] + + # Set up metrics collector + self.max_batch_tokens = cache.max_batch_tokens + self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens) + + self.setup_static_tensors() + + self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path) + self.decode_stream = DecodeStream(skip_special_tokens=True) + + def return_attention_mask(self) -> bool: + return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call + + @traced(standalone=True) + def setup_static_tensors(self): + T = self.max_batch_tokens + max_token_budget = self.cache.num_blocks * self.cache.block_size + tensor_metadata = {"dtype": torch.int32, "device": self.model_device} + # Prepare empty tensors + self.tensor_metadata = tensor_metadata + self.input_ids = torch.empty((1, T), **tensor_metadata) + self.position_ids = torch.empty((1, T), **tensor_metadata) + self.cumulative_seqlens_q = 
torch.empty((T + 1,), **tensor_metadata) + self.cumulative_seqlens_k = torch.empty((T + 1,), **tensor_metadata) + self.write_index = torch.empty((T,), **tensor_metadata) + self.read_index = torch.empty((max_token_budget,), **tensor_metadata) + self.logits_indices = torch.empty((T,), **tensor_metadata) + self.max_seqlen_q = 0 + self.max_seqlen_k = 0 + self.output_ids = torch.empty((1, T), **tensor_metadata) + # Since attenention_mask is not always needed, we only allocate it if it is needed + if self.return_attention_mask(): + self.attention_mask = torch.empty( + (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device + ) + else: + self.attention_mask = None + # Initialize the tensors by pretending they are in full use + self.actual_tokens = T + self.cache_used = max_token_budget + self.reset_static_tensors() + # Reset stats to 0 + self.actual_tokens = 0 + self.cache_used = 0 + + @traced + @torch.no_grad() + def reset_static_tensors(self): + """Reset static tensors for the next batch.""" + # Compute the slice to reset + t = self.actual_tokens if self.slice_inputs else self.write_index.size(0) + c = self.cache_used if self.slice_inputs else self.read_index.size(0) + # Reset the tensors + self.input_ids[:, :t].zero_() + self.position_ids[:, :t].zero_() + self.cumulative_seqlens_q[: t + 1].zero_() + self.cumulative_seqlens_k[: t + 1].zero_() + self.write_index[:t].fill_(-1) + self.read_index[:c].fill_(-1) + self.logits_indices[:t].fill_(-1) + self.max_seqlen_q = 0 + self.max_seqlen_k = 0 + self.output_ids[:, :t].fill_(-1) + if self.attention_mask is not None: + self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) + + def get_model_kwargs(self) -> PagedAttentionArgs: + """Get model keyword arguments for the current batch.""" + # Compute the slice to return + t = self.actual_tokens if self.slice_inputs else self.write_index.size(0) + c = self.cache_used if self.slice_inputs else self.read_index.size(0) + # Prepare the kwargs + kwargs = { + "input_ids": self.input_ids[:, :t], + "attention_mask": self.attention_mask, + "position_ids": self.position_ids[:, :t], + "cu_seq_lens_q": self.cumulative_seqlens_q[: t + 1], + "cu_seq_lens_k": self.cumulative_seqlens_k[: t + 1], + "write_index": self.write_index[:t], + "read_index": self.read_index[:c], + "logits_indices": self.logits_indices[:t], + "max_seqlen_q": self.max_seqlen_q, + "max_seqlen_k": self.max_seqlen_k, + "block_tables": self.cache._block_tables, + "cache": self.cache, + "use_cache": False, + } + # If the attention mask is not None, we slice it as the others + if self.attention_mask is not None: + kwargs["attention_mask"] = self.attention_mask[:, :, :t, :c] + return kwargs + + def __repr__(self): + return ( + f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})" + + self.get_model_kwargs().__repr__() + ) + + @traced + def _get_new_requests(self): + """Pull new requests from the input queue and add to waiting list.""" + while not self.input_queue.empty(): + try: + state = self.input_queue.get_nowait() + if state is None: # Sentinel value + continue + self.scheduler.add_waiting_request(state) + + except queue.Empty: + break + except Exception as e: + logger.error(f"Error processing new request: {e}", exc_info=True) + state: RequestState = locals().get("state") + if state is not None: + self._handle_request_error(e, state) + + @traced + def 
_handle_request_error(self, error, state: RequestState): + """Handle general request processing error.""" + state.status = RequestStatus.FAILED + state.error = str(error) + + # Include any generated tokens if this is an active request + if isinstance(state.request_id, str): + state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id) + else: + state.static_outputs = [] + + self.metrics.record_request_completion(state.created_time, state.request_id) + self.output_queue.put(state.to_generation_output()) + + @traced + def prepare_next_batch(self): + """Prepare tensors and metadata for the next model forward pass.""" + # Get new requests from the queue + self._get_new_requests() + if not self.scheduler.has_pending_requests(): + return None + + self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests)) + + self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens) + if not self.requests_in_batch: + return None + + # Get the request objects for this batch + self.reset_static_tensors() + position_ids = [] + input_ids = [] + read_index = [] + write_index = [] + cumulative_seqlens_q = [0] + cumulative_seqlens_k = [0] + logits_indices = [] + self.metrics.record_batch_metrics(self.requests_in_batch) + + for state in self.requests_in_batch: + next_input_ids = state.prompt_ids + input_ids.extend(next_input_ids) + past_length = state.position_offset + query_length = len(next_input_ids) + key_length = query_length + past_length + cache_index = list(range(key_length)) + + positions_to_add = cache_index[past_length:] + read_indices = self.cache._get_physical_indices(state, cache_index) + write_indices = read_indices[-query_length:] + + position_ids.extend(positions_to_add) + read_index.extend(read_indices) + write_index.extend(write_indices) + cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length) + cumulative_seqlens_k.append(cumulative_seqlens_k[-1] + key_length) + if len(state.remaining_prompt_ids) == 0: + logits_indices.append(cumulative_seqlens_q[-1] - 1) + self.max_seqlen_q = max(self.max_seqlen_q, query_length) + self.max_seqlen_k = max(self.max_seqlen_k, key_length) + state.position_offset += query_length + + logger.debug( + f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, " + f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. 
" + f"cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}" + ) + self._build_tensors( + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ) + + self.metrics.record_kv_cache_memory_metrics(self.cache) + + @traced + def _build_tensors( + self, + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ): + to_tensor = partial(torch.tensor, **self.tensor_metadata) + self.input_ids[:, : len(input_ids)] = to_tensor(input_ids) + self.position_ids[:, : len(position_ids)] = to_tensor(position_ids) + self.write_index[: len(write_index)] = to_tensor(write_index) + self.read_index[: len(read_index)] = to_tensor(read_index) + self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q) + self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k) + self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices) + + self.actual_tokens = len(input_ids) + self.cache_used = len(read_index) + + min_value = torch.finfo(self.model_dtype).min + if self.attention_mask is not None: + for i in range(len(cumulative_seqlens_q) - 1): + if ( + cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] + < cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i] + and cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] >= 1 + ): + diagonal = ( + cumulative_seqlens_k[i + 1] - (cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]) + 1 + ) + diagonal = diagonal - cumulative_seqlens_k[i] + else: + diagonal = 1 + query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1]) + key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1]) + + mask = torch.triu( + torch.full( + self.attention_mask[..., query_range, key_range].shape, + min_value, + dtype=self.model_dtype, + device=self.model_device, + ), + diagonal=diagonal, + ) + self.attention_mask[..., query_range, key_range] = mask + + @traced + def _sync(self): + if self.output_ids is not None: + try: + out = self.output_ids.tolist()[0] # should be the only synch we do + except Exception: + out = [0, 1] + else: + out = [0, 0] + return out + + @traced + def _maybe_send_output(self, state: RequestState, token: int): + """Send output to the queue based on streaming mode and request state.""" + if self.streaming: + state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1]) + self.output_queue.put(state.to_generation_output()) + elif state.status == RequestStatus.FINISHED: + self.output_queue.put(state.to_generation_output()) + + @traced + def update_batch(self): + """Update request states based on generated tokens.""" + out_tokens = self._sync() + finished_request_ids = [] + for i, state in enumerate(self.requests_in_batch): + req_id = state.request_id + if len(state.remaining_prompt_ids) == 0: + self.metrics.record_ttft_metric(state.created_time, state.request_id) + state.status = RequestStatus.DECODING + token = out_tokens[self.logits_indices[i]] + state.prompt_ids = [token] + if state.update_with_token(token): + self.metrics.record_request_completion(state.created_time, state.request_id) + self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction)) + finished_request_ids.append(req_id) + self._maybe_send_output(state, token) + elif state.status == RequestStatus.PREFILLING_SPLIT: + state.status = RequestStatus.SPLIT_PENDING_REMAINDER + if 
self.cache.get_num_free_blocks() == 0: + raise ValueError("No more free blocks") + + @traced + def has_pending_requests(self) -> bool: + """Check if there are any active or waiting requests.""" + return self.scheduler.has_pending_requests() + + @traced + def handle_batch_error(self, error): + """Handle errors during batch processing.""" + failed_reqs = self.requests_in_batch + for req in failed_reqs: + self._handle_request_error(error, req) + self.scheduler.finish_request(req.request_id) + + @traced + def fail_all_requests(self, error): + """Fail all active requests with the given error. + + Args: + error: The error to report in the failure message + """ + + requests = list(self.scheduler.active_requests.values()) + for state in requests: + self._handle_request_error(error, state) + self.scheduler.finish_request(state.request_id) + + # Also fail any requests in the waiting queue + for req_id in list(self.scheduler.waiting_requests.keys()): + state = self.scheduler.waiting_requests.pop(req_id) + self._handle_request_error(error, state) + + # Clear the ordering queue + self.scheduler.waiting_requests_order.clear() + + +# Manager Class (User Interface) +@attach_tracer() +class ContinuousBatchingManager: + """Manager for handling continuous batching of generation requests. + + This class provides the user interface for submitting generation requests, + retrieving results, and managing the background generation thread. + """ + + def __init__( + self, + model, + generation_config: GenerationConfig, + manual_eviction: bool = False, + max_queue_size=0, + streaming: bool = True, + slice_inputs: bool = True, + ): + """Initialize the continuous batching manager. + + Args: + model: The language model for generation + generation_config: Configuration for generation parameters + max_queue_size: Maximum size of the request queue (0 = unlimited) + streaming: Whether to stream tokens as they are generated + """ + self.model = model.eval() + generation_config = model.generation_config if generation_config is None else generation_config + self.generation_config = generation_config + self.input_queue = queue.Queue(maxsize=max_queue_size) + self.output_queue = queue.Queue() + self.stop_event = threading.Event() + self.streaming = streaming + self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) + self._generation_thread = None + self._request_counter = 0 + self._request_lock = threading.Lock() + self.model.generation_config.top_p = None + self.do_sample = getattr(generation_config, "do_sample", True) + self.logit_processor = self.model._get_logits_processor(generation_config) + self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) + self.profile = getattr(generation_config, "profile", False) + self.manual_eviction = manual_eviction + self.batch_processor: Optional[ContinuousBatchProcessor] = None + self.decode_stream = DecodeStream(skip_special_tokens=True) + self.slice_inputs = slice_inputs + + @traced + def start(self): + """Start the background generation thread.""" + if self._generation_thread is not None and self._generation_thread.is_alive(): + logger.warning("Manager thread is already running.") + return + + self._result_queue = queue.Queue() + self._generation_thread = threading.Thread(target=self._run_generation_loop) + self._generation_thread.start() + logger.info("Continuous batching manager started.") + + def is_running(self): + """Check if the background generation thread is running.""" + return self._generation_thread is not None and 
self._generation_thread.is_alive()
+
+    def stop(self, block: bool = False, timeout: Optional[float] = None):
+        """Signal the background thread to stop.
+
+        Args:
+            block: Whether to wait for the thread to stop
+            timeout: Maximum time to wait for the thread to stop
+        """
+        if self._generation_thread is None:
+            logger.warning("Manager not started.")
+            return
+
+        if not self.stop_event.is_set():
+            self.stop_event.set()
+            logger.info("Stopping continuous batching manager...")
+
+        if block:
+            self.join(timeout)
+
+    def join(self, timeout: Optional[float] = None):
+        """Wait for the background thread to finish.
+
+        Args:
+            timeout: Maximum time to wait for the thread to stop
+        """
+        if self._generation_thread is not None:
+            self._generation_thread.join(timeout=timeout)
+            if self._generation_thread.is_alive():
+                logger.warning("Generation thread did not exit after join timeout.")
+            else:
+                logger.info("Continuous Batching Manager stopped.")
+                self._generation_thread = None
+
+    def add_request(
+        self, input_ids: list[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None
+    ) -> str:
+        """Add a new generation request to the queue.
+
+        Args:
+            input_ids: Input token IDs to use as prompt
+            request_id: Optional custom request ID (auto-generated if None)
+            max_new_tokens: Optional cap on the number of tokens generated for this request
+                (defaults to `generation_config.max_new_tokens`)
+
+        Returns:
+            str: The request ID
+        """
+        if request_id is None:
+            with self._request_lock:
+                request_id = f"req_{self._request_counter}"
+                self._request_counter += 1
+
+        max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens
+
+        state = RequestState(
+            request_id=request_id,
+            prompt_ids=list(input_ids),
+            full_prompt_ids=list(input_ids),
+            max_new_tokens=max_new_tokens,
+            eos_token_id=self.generation_config.eos_token_id,
+        )
+
+        # Use block=True with timeout to handle backpressure if queue is full
+        self.input_queue.put(state, block=True, timeout=10)  # XXX: pass timeout as fn arg?
+        logger.debug(f"Added request {request_id} to queue.")
+        return request_id
+
+    def add_requests(self, inputs: list[list[int]], **kwargs):
+        for i, input_ids in enumerate(inputs):
+            # Assign a predictable request ID for ordering results later
+            req_id = f"batch_req_{i}"
+            self.add_request(input_ids, request_id=req_id, **kwargs)
+
+    def get_result(self, timeout=None) -> Optional[GenerationOutput]:
+        """Retrieve one result from the output queue.
+
+        Args:
+            timeout: Maximum time to wait for a result
+
+        Returns:
+            Optional[GenerationOutput]: The result data, or None if the timeout expired
+        """
+        if self._generation_thread is None and self.output_queue.empty():
+            return None
+        try:
+            result = self.output_queue.get(block=True, timeout=timeout)
+            logger.debug(f"Retrieved result for request {result.request_id}")
+            return result
+        except queue.Empty:
+            return None
+
+    def __iter__(self):
+        """Iterate over results as they become available."""
+        while (
+            self._generation_thread is not None and self._generation_thread.is_alive() or not self.output_queue.empty()
+        ):
+            result = self.get_result(timeout=0.1)  # short poll so we notice when the generation thread stops
+            if result is not None:
+                yield result
+
+    @traced
+    def warmup(self, batch_processor):
+        stream = torch.cuda.Stream(device=self.model.device)
+        stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(stream):
+            # Warmup the model with a dummy forward pass
+            self._generation_step(batch_processor)
+        torch.cuda.current_stream().wait_stream(stream)
+
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, stream=stream):
+            self._generation_step(batch_processor)
+
+    @traced
+    # @torch.compile
+    def _generation_step(self, batch_processor: ContinuousBatchProcessor):
+        """Perform a single generation step. This is CUDA-graphed when enabled."""
+        batch_data = batch_processor.get_model_kwargs()
+        with torch.no_grad():
+            logits = self._model_forward(batch_data)
+            if self.log_prob_generation:
+                batch_processor.output_probs.copy_(logits)  # TODO
+            probs = self._process_logit(batch_data, logits)
+            self._sample(batch_processor, probs)
+
+    @traced(span_name="model_forward")
+    def _model_forward(self, batch_data):
+        return self.model(**batch_data).logits
+
+    @traced(span_name="logit_processing")
+    def _process_logit(self, batch_data, logits):
+        # Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner!
+        if hasattr(self.logit_processor, "set_continuous_batching_context"):
+            self.logit_processor.set_continuous_batching_context(
+                batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
+            )
+        return self.logit_processor(batch_data["input_ids"], logits)
+
+    @traced(span_name="sampling")
+    def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
+        if self.do_sample:  # sample
+            probs = nn.functional.softmax(probs, dim=-1)
+            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1)
+        else:
+            next_tokens = torch.argmax(probs, dim=-1)
+        tokens = next_tokens.size(1)
+        batch_processor.output_ids[:, :tokens].copy_(next_tokens)
+
+    def _run_generation_loop(self):
+        """Main processing loop running in the background thread."""
+        batch_processor = None
+        try:
+            paged_attention_cache = PagedAttentionCache(
+                self.model.config,
+                self.generation_config,
+                self.model.device,
+                self.model.dtype,
+                num_requests=len(self.input_queue.queue),
+                tp_size=getattr(self.model, "_tp_size", None),  # Use model's actual TP setting
+            )
+
+            scheduler = None
+            if hasattr(self.generation_config, "scheduler"):
+                scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler, None)
+                if scheduler is None:
+                    logger.warning(f"Scheduler '{self.generation_config.scheduler}' not found. 
Defaulting to FIFO.") + scheduler = FIFOScheduler + else: + # Default to fifo + scheduler = FIFOScheduler + + batch_processor = ContinuousBatchProcessor( + paged_attention_cache, + self.model.config, + self.generation_config, + self.input_queue, + self.output_queue, + self.stop_event, + self.model.device, + self.model.dtype, + scheduler(paged_attention_cache, self.manual_eviction), + self.streaming, + self.manual_eviction, + slice_inputs=self.slice_inputs, + ) + self.batch_processor = batch_processor + self.current_batch = 0 + while (not self.stop_event.is_set()) or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor) + self.current_batch += 1 + + except Exception as e: + logger.error(f"Error in generation loop: {e}", exc_info=True) + self._handle_critical_error(e, batch_processor) + finally: + logger.info("Generation loop finished.") + + @traced(span_name="generation_loop") + def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor): + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.prepare_next_batch() + device, total, reserved, allocated = get_device_and_memory_breakdown() + logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") + if torch.cuda.is_available() and self.use_cuda_graph: + if self.current_batch == 0: + self.warmup(batch_processor) + elif hasattr(self, "graph"): + try: + self._graph_replay() + except Exception as e: + logger.error(f"Model forward pass failed: {e}", exc_info=True) + batch_processor.handle_batch_error(e) + return + else: + self._generation_step(batch_processor) + else: + self._generation_step(batch_processor) + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.update_batch() + + @traced(span_name="graph_replay") + def _graph_replay(self): + self.graph.replay() + + @traced + def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]): + """Handle critical errors that terminate the generation loop.""" + # Signal stop + self.stop_event.set() + + # Fail pending requests in input queue + try: + while True: + req_data = self.input_queue.get_nowait() + if batch_processor is not None: + batch_processor._handle_request_error(error, req_data) + except queue.Empty: + pass + + # Fail active requests + if batch_processor is not None: + batch_processor.fail_all_requests(error) + + @traced + def evict_request_from_cache(self, request_id: str): + """Evict a request from the cache. It is assumed that the request is already finished.""" + if not self.manual_eviction: + raise RuntimeError("Manual eviction is not enabled for this manager.") + if self.batch_processor is not None: + self.batch_processor.scheduler.finish_request(request_id) + + +class ContinuousMixin: + """Mixin class for models to add continuous batching capabilities.""" + + def init_continuous_batching( + self, + generation_config: Optional[GenerationConfig] = None, + manual_eviction: bool = False, + max_queue_size: int = 0, + streaming: bool = False, + slice_inputs: bool = True, + ) -> ContinuousBatchingManager: + """Initialize a manager for continuous batching inference. + + Args: + generation_config: Custom generation configuration + max_queue_size: Maximum size of the input request queue + streaming: Whether to stream tokens as they are generated + + Returns: + `ContinuousBatchingManager`: The manager instance to add requests and retrieve results. 
+        """
+        if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"):
+            raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.")
+
+        gen_config = generation_config if generation_config is not None else self.generation_config
+        if gen_config is None:
+            raise ValueError("A GenerationConfig must be provided or set in the model.")
+
+        if gen_config.eos_token_id is None:
+            logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).")
+            gen_config.eos_token_id = -1
+
+        # Create and return the manager
+        return ContinuousBatchingManager(
+            model=self,
+            generation_config=gen_config,
+            manual_eviction=manual_eviction,
+            max_queue_size=max_queue_size,
+            streaming=streaming,
+            slice_inputs=slice_inputs,
+        )
+
+    @traced
+    @torch.inference_mode()
+    def generate_batch(
+        self,
+        inputs: list[list[int]],
+        generation_config: Optional[GenerationConfig] = None,
+        progress_bar: bool = True,
+        slice_inputs: bool = True,
+        **kwargs,
+    ) -> dict[str, GenerationOutput]:
+        """Generate sequences for a batch of prompts using continuous batching.
+
+        Args:
+            inputs: List of input token sequences (prompts)
+            generation_config: Optional generation configuration
+            **kwargs: Additional generation parameters
+
+        Returns:
+            `dict[str, GenerationOutput]`: A mapping from request ID (`batch_req_{i}` for the i-th prompt)
+             to its finished generation output. Requests that failed or did not finish are absent from
+             the mapping.
+        """
+        if not inputs:
+            return {}
+        if logger.getEffectiveLevel() <= logging.DEBUG:
+            logger.warning("Progress bar is disabled when the logger level is DEBUG or more verbose")
+            progress_bar = False
+
+        # Initialize manager with the batch inputs
+        manager = self.init_continuous_batching(generation_config=generation_config, slice_inputs=slice_inputs)
+        manager.start()
+        results = {}
+        num_requests = len(inputs)
+        try:
+            from tqdm.contrib.logging import logging_redirect_tqdm
+
+            with logging_redirect_tqdm([logger]):
+                with tqdm(
+                    total=num_requests,
+                    disable=(not progress_bar),
+                    desc=f"Solving {num_requests} requests",
+                    unit="request",
+                ) as pbar:
+                    manager.add_requests(inputs, **kwargs)
+                    finished_count = 0
+                    while finished_count < num_requests:
+                        result = manager.get_result(timeout=1)
+                        if result:
+                            req_id = result.request_id
+                            if result.status == RequestStatus.FINISHED:
+                                results[req_id] = result
+                                finished_count += 1
+                                pbar.update(1)
+                                logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
+                        else:
+                            if not manager.is_running():
+                                logger.error("Generation thread terminated unexpectedly.")
+                                break
+
+        except Exception as e:
+            logger.error(f"Error during batch generation: {e}", exc_info=True)
+        finally:
+            manager.stop(block=True, timeout=5.0)
+        return results
diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py
new file mode 100644
index 000000000000..9f612c9380ff
--- /dev/null
+++ b/src/transformers/generation/continuous_batching/scheduler.py
@@ -0,0 +1,314 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from collections import deque + +from ...utils.metrics import attach_tracer, traced +from .cache import PagedAttentionCache +from .classes import RequestState, RequestStatus + + +class Scheduler(ABC): + """ + Abstract base class for scheduling requests in the continuous batch processor. + It is expected that cache allocation and scheduling logic will be implemented in subclasses. + """ + + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False): + self.active_requests: dict[str, RequestState] = {} + self.waiting_requests: dict[str, RequestState] = {} + self.waiting_requests_order: deque[str] = deque() + self.cache = cache + self.retain_cache_on_finish = retain_cache_on_finish + + @abstractmethod + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + pass + + @abstractmethod + def schedule_batch(self, token_budget: int) -> list[RequestState]: + pass + + @traced + def has_pending_requests(self) -> bool: + """Check if there are requests ready to be processed.""" + return len(self.active_requests) or len(self.waiting_requests) + + @abstractmethod + def finish_request(self, request_id: str, evict_from_cache: bool = True): + """Finish processing a request and free its allocated blocks.""" + pass + + @traced + def get_active_request_static_outputs(self, request_id: str) -> list[int]: + if request_id in self.active_requests: + return self.active_requests[request_id].static_outputs + return [] + + +@attach_tracer() +class FIFOScheduler(Scheduler): + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.0): + super().__init__(cache, retain_cache_on_finish) + self.safety_margin = safety_margin + + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. 
we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + if self.retain_cache_on_finish and state.request_id in self.active_requests: + old_state = self.active_requests.pop(state.request_id) + state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] + state.allocated_blocks = old_state.allocated_blocks + state.position_offset = old_state.position_offset + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> list[RequestState]: + priority_states: list[RequestState] = [] + second_priority_states: list[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.DECODING: + priority_states.append(state) + if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + second_priority_states.append(state) + + # Add waiting requests to second priority + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + request_ids_to_remove_from_waiting = set() + safety_margins = self.safety_margin * self.cache.num_blocks + + for state in candidates: + # If we are out the safety margin, we only accept decoding requests or the first prefill request + num_free_blocks = self.cache.get_num_free_blocks() + outside_safety_margin = num_free_blocks < safety_margins + if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING: + break + + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not 
self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, request_id: str, evict_from_cache: bool = True): + if evict_from_cache: + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +@attach_tracer() +class PrefillFirstScheduler(Scheduler): + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + if self.retain_cache_on_finish and state.request_id in self.active_requests: + old_state = self.active_requests.pop(state.request_id) + state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error? 
+ state.allocated_blocks = old_state.allocated_blocks + state.position_offset = old_state.position_offset + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> list[RequestState]: + priority_states: list[RequestState] = [] + second_priority_states: list[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + priority_states.append(state) + elif state.status == RequestStatus.DECODING: + second_priority_states.append(state) + + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, request_id: str, evict_from_cache: bool = True): + if evict_from_cache: + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +SCHEDULER_MAPPING = { + "fifo": FIFOScheduler, + "prefill_first": PrefillFirstScheduler, +}
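
Taken together, `ContinuousMixin.generate_batch` and the `SCHEDULER_MAPPING` above give the full request path: prompts are queued, scheduled block-by-block into the paged cache, and returned as `GenerationOutput` objects keyed by request ID. The sketch below is not part of the diff; it is a minimal usage example that assumes a causal LM whose `generate_batch`/`init_continuous_batching` come from `ContinuousMixin` (the Qwen checkpoint used elsewhere in these examples), a CUDA-capable machine, and only the `GenerationConfig` attributes the code above actually reads (`max_new_tokens`, `eos_token_id`, and the optional `scheduler` name looked up in `SCHEDULER_MAPPING`).

```py
# Minimal usage sketch (illustrative, not part of this diff).
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model_id = "Qwen/Qwen3-4B-Instruct-2507"  # same checkpoint as the continuous_batching.py example
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda().eval()

generation_config = GenerationConfig(
    max_new_tokens=64,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=False,
)
# The manager reads this attribute via hasattr/getattr, so it can be set after construction;
# valid names are the keys of SCHEDULER_MAPPING, and "fifo" is used when unset or unknown.
generation_config.scheduler = "prefill_first"

prompts = ["What is continuous batching?", "Name three prime numbers."]
inputs = [tokenizer(p)["input_ids"] for p in prompts]

# Blocking API: returns a dict keyed by request ID ("batch_req_0", "batch_req_1", ...)
results = model.generate_batch(inputs=inputs, generation_config=generation_config)
for req_id, output in results.items():
    print(req_id, tokenizer.decode(output.generated_tokens, skip_special_tokens=True))
```

For finer-grained control (token streaming, manual cache eviction), `init_continuous_batching` returns the `ContinuousBatchingManager` directly; call `add_request`/`get_result` yourself and remember to `stop(block=True)` the background thread when you are done.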