From f101990e04fbe47d2b4be5c88d823084e1d8b59a Mon Sep 17 00:00:00 2001 From: Amey Agrawal Date: Sun, 26 May 2024 20:55:05 -0700 Subject: [PATCH] Port of PR 1813: Replace yapf with black (#10) --- .github/PULL_REQUEST_TEMPLATE.md | 46 ++ .github/workflows/lint.yml | 26 + .github/workflows/yapf.yml | 30 - .gitignore | 2 +- .pylintrc | 434 ------------ CONTRIBUTING.md | 6 +- Makefile | 24 + README.md | 65 +- config.yml | 38 -- environment-dev.yml | 12 + format.sh | 101 --- requirements-dev.txt | 17 +- sarathi/__init__.py | 4 +- sarathi/benchmark/benchmark_runner.py | 142 ++-- .../capacity_search/capacity_search.py | 66 +- .../capacity_search/config/__init__.py | 4 +- .../capacity_search/config/config.py | 28 +- sarathi/benchmark/capacity_search/main.py | 38 +- .../benchmark/capacity_search/ray_utils.py | 45 +- .../capacity_search/search_manager.py | 17 +- sarathi/benchmark/config/config.py | 31 +- sarathi/benchmark/constants.py | 2 +- sarathi/benchmark/main.py | 8 +- .../benchmark/request_generator/__init__.py | 3 +- .../fixed_request_length_generator.py | 9 +- .../gamma_request_interval_generator.py | 3 +- .../poisson_request_interval_generator.py | 3 +- .../request_generator_registry.py | 16 +- .../request_interval_generator_registry.py | 32 +- .../request_length_generator_registry.py | 34 +- .../static_request_interval_generator.py | 3 +- .../synthetic_request_generator.py | 40 +- .../trace_replay_request_generator.py | 60 +- .../trace_request_interval_generator.py | 36 +- .../trace_request_length_generator.py | 64 +- .../uniform_request_length_generator.py | 14 +- .../zipf_request_length_generator.py | 6 +- sarathi/benchmark/types/__init__.py | 15 +- .../types/request_length_generator_type.py | 2 +- sarathi/benchmark/utils/zipf_generator.py | 15 +- sarathi/config.py | 120 ++-- .../base_block_space_manager.py | 3 +- .../block_space_manager_registry.py | 33 +- .../faster_transformer_block_space_manager.py | 4 +- .../orca_block_space_manager.py | 4 +- .../sarathi_block_space_manager.py | 4 +- .../simple_chunking_block_space_manager.py | 4 +- .../vllm_block_space_manager.py | 4 +- sarathi/core/datatypes/block.py | 11 +- sarathi/core/datatypes/request_output.py | 14 +- sarathi/core/datatypes/sampling_params.py | 26 +- sarathi/core/datatypes/scheduler_output.py | 24 +- sarathi/core/datatypes/sequence.py | 47 +- sarathi/core/datatypes/sequence_state.py | 93 ++- sarathi/core/policy.py | 2 +- sarathi/core/scheduler/base_scheduler.py | 14 +- .../scheduler/faster_transformer_scheduler.py | 14 +- sarathi/core/scheduler/orca_scheduler.py | 14 +- sarathi/core/scheduler/sarathi_scheduler.py | 60 +- sarathi/core/scheduler/scheduler_registry.py | 14 +- .../scheduler/simple_chunking_scheduler.py | 33 +- sarathi/core/scheduler/vllm_scheduler.py | 26 +- .../sequence_manager/base_sequence_manager.py | 32 +- .../engine_sequence_manager.py | 27 +- .../worker_sequence_manager.py | 10 +- sarathi/engine/arg_utils.py | 104 +-- sarathi/engine/base_llm_engine.py | 108 +-- .../engine/pipeline_parallel_llm_engine.py | 58 +- sarathi/engine/ray_utils.py | 19 +- sarathi/metrics/cdf_sketch.py | 76 +-- sarathi/metrics/constants.py | 21 +- sarathi/metrics/cpu_timer.py | 10 +- sarathi/metrics/cuda_timer.py | 22 +- sarathi/metrics/data_series.py | 92 ++- sarathi/metrics/metrics_store.py | 615 ++++++++++-------- sarathi/model_executor/attention/__init__.py | 16 +- .../attention/base_attention_wrapper.py | 11 +- .../attention/flash_attention_wrapper.py | 102 +-- .../attention/flashinfer_attention_wrapper.py | 72 +- 
.../flashinfer_unpaged_attention_wrapper.py | 145 +++-- sarathi/model_executor/attention/kv_buffer.py | 6 +- .../attention/no_op_attention_wrapper.py | 8 +- sarathi/model_executor/layers/activation.py | 1 + sarathi/model_executor/layers/layernorm.py | 4 +- .../model_executor/layers/rotary_embedding.py | 174 ++--- sarathi/model_executor/layers/sampler.py | 57 +- sarathi/model_executor/model_loader.py | 12 +- sarathi/model_executor/model_runner.py | 83 ++- sarathi/model_executor/models/__init__.py | 2 +- sarathi/model_executor/models/falcon.py | 209 +++--- sarathi/model_executor/models/internlm.py | 137 ++-- sarathi/model_executor/models/llama.py | 178 ++--- sarathi/model_executor/models/mistral.py | 173 ++--- sarathi/model_executor/models/qwen.py | 97 ++- sarathi/model_executor/models/yi.py | 149 +++-- .../parallel_utils/parallel_state.py | 156 +++-- .../pipeline_parallel/mappings.py | 4 +- .../tensor_parallel/__init__.py | 22 +- .../parallel_utils/tensor_parallel/layers.py | 247 ++++--- .../tensor_parallel/mappings.py | 44 +- .../parallel_utils/tensor_parallel/random.py | 24 +- .../parallel_utils/tensor_parallel/utils.py | 30 +- sarathi/model_executor/utils.py | 9 +- sarathi/model_executor/weight_utils.py | 48 +- sarathi/transformers_utils/config.py | 18 +- .../transformers_utils/configs/__init__.py | 2 +- sarathi/transformers_utils/configs/falcon.py | 4 +- sarathi/transformers_utils/configs/yi.py | 6 +- sarathi/transformers_utils/tokenizer.py | 47 +- sarathi/utils/__init__.py | 2 +- sarathi/utils/singleton.py | 5 +- sarathi/utils/threading_utils.py | 9 +- sarathi/worker/base_worker.py | 92 ++- sarathi/worker/cache_engine.py | 6 +- sarathi/worker/pipeline_parallel_worker.py | 44 +- 115 files changed, 2943 insertions(+), 2815 deletions(-) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/workflows/lint.yml delete mode 100644 .github/workflows/yapf.yml delete mode 100644 .pylintrc create mode 100644 Makefile delete mode 100644 config.yml create mode 100644 environment-dev.yml delete mode 100755 format.sh diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..84ef7cb --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,46 @@ +FILL IN THE PR DESCRIPTION HERE + +FIX #xxxx (*link existing issues this PR will resolve*) + +**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** + +--- + +
+<details>
+<summary><b> PR Checklist (Click to Expand) </b></summary>
+
+<p>Thank you for your contribution to Vidur! Before submitting the pull request, please ensure the PR meets the following criteria. This helps Vidur maintain the code quality and improve the efficiency of the review process.</p>
+
+<h3>PR Title and Classification</h3>
+
+<p>Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:</p>
+
+<p><strong>Note:</strong> If the PR spans more than one category, please include all relevant prefixes.</p>
+
+<h3>Code Quality</h3>
+
+<p>The PR needs to meet the following code quality standards:</p>
+
+<h3>Notes for Large Changes</h3>
+
+<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
+
+<h3>Thank You</h3>
+
+<p>Finally, thank you for taking the time to read these guidelines and for your interest in contributing to Vidur. Your contributions make Vidur a great tool for everyone!</p>
+
+</details>
\ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..b6629cd --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: "Run linter" +on: + push: + branches: [main] + pull_request: + branches: [main] +permissions: + contents: read + packages: write +defaults: + run: + shell: bash -l {0} +jobs: + sanity_check: + runs-on: "ubuntu-latest" + steps: + - name: "Checkout Repository" + uses: actions/checkout@v3 + - name: Install Conda environment from environment-dev.yml + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: environment-dev.yml + - name: "Run black lint" + run: make lint/black + - name: "Run isort check" + run: make lint/isort diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml deleted file mode 100644 index 87555b9..0000000 --- a/.github/workflows/yapf.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: yapf - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - pull_request: - branches: - - main -jobs: - yapf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10"] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -r requirements-dev.txt - - name: Running yapf - run: | - yapf --diff --recursive sarathi --exclude 'sarathi/model_executor/parallel_utils/**' diff --git a/.gitignore b/.gitignore index ad0cd0c..0e3b8ef 100644 --- a/.gitignore +++ b/.gitignore @@ -122,7 +122,7 @@ celerybeat.pid # Environments .env .venv -env* +env/ env_sarathi/ env_flashinfer/ env_flashinfer_2/ diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 1712c0e..0000000 --- a/.pylintrc +++ /dev/null @@ -1,434 +0,0 @@ -# This Pylint rcfile contains a best-effort configuration to uphold the -# best-practices and style described in the Google Python style guide: -# https://google.github.io/styleguide/pyguide.html -# -# Its canonical open-source location is: -# https://google.github.io/styleguide/pylintrc - -[MASTER] - -# Files or directories to be skipped. They should be base names, not paths. -ignore=parallel_utils - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. -ignore-patterns= - -# Pickle collected data for later comparisons. -persistent=no - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Use multiple processes to speed up Pylint. -jobs=4 - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -#enable= - -# Disable the message, report, category or checker with the given id(s). 
You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=abstract-method, - apply-builtin, - arguments-differ, - attribute-defined-outside-init, - backtick, - bad-option-value, - basestring-builtin, - buffer-builtin, - c-extension-no-member, - consider-using-enumerate, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, - delslice-method, - div-method, - duplicate-code, - eq-without-hash, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, - fixme, - getslice-method, - global-statement, - hex-method, - idiv-method, - implicit-str-concat-in-sequence, - import-error, - import-self, - import-star-module-level, - inconsistent-return-statements, - input-builtin, - intern-builtin, - invalid-str-codec, - locally-disabled, - logging-fstring-interpolation, # added by Sarathi - logging-not-lazy, # added by Sarathi - long-builtin, - long-suffix, - map-builtin-not-iterating, - misplaced-comparison-constant, - missing-class-docstring, # TODO (Sarathi): enable - missing-function-docstring, - missing-module-docstring, # TODO (Sarathi): enable - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, - no-else-break, - no-else-continue, - no-else-raise, - no-else-return, - no-init, # added - no-member, - no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, - raising-string, - range-builtin-not-iterating, - raw_input-builtin, - rdiv-method, - reduce-builtin, - relative-import, - reload-builtin, - round-builtin, - setslice-method, - signature-differs, - standarderror-builtin, - suppressed-message, - sys-max-int, - too-few-public-methods, - too-many-ancestors, - too-many-arguments, - too-many-boolean-expressions, - too-many-branches, - too-many-instance-attributes, - too-many-locals, - too-many-nested-blocks, - too-many-public-methods, - too-many-return-statements, - too-many-statements, - trailing-newlines, - unichr-builtin, - unicode-builtin, - unnecessary-pass, - unpacking-in-except, - unspecified-encoding, - useless-else-on-loop, - useless-object-inheritance, - useless-suppression, - using-cmp-argument, - wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). 
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl - -# Regular expression matching correct function names -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression matching correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct constant names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression matching correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class attribute names -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct inline iteration names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression matching correct module names -module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ - -# Regular expression matching correct method names -method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=10 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. 
-ignored-classes=optparse.Values,thread._local,_thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=80 - -# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt -# lines made too long by directives to pytype. - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(?x)( - ^\s*(\#\ )??$| - ^\s*(from\s+\S+\s+)?import\s+.+$) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. The internal Google style guide mandates 2 -# spaces. Google's externaly-published style guide says 4, consistent with -# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google -# projects (like TensorFlow). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=TODO - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging,absl.logging,tensorflow.io.logging - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub, - TERMIOS, - Bastion, - rexec, - sets - -# Create a graph of every (i.e. 
internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant, absl - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls, - class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4a71d72..1801a30 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,11 @@ If not, please file a new issue, providing as much relevant information as possi In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html). -We include a formatting script [`format.sh`](./format.sh) to format the code. +To format the code run, + +```sh +make format +``` ### Pull Requests diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a12f980 --- /dev/null +++ b/Makefile @@ -0,0 +1,24 @@ +.PHONY: help lint lint/flake8 lint/black lint/isort format format/black format/autopep8 format/isort +.DEFAULT_GOAL := help + +lint/flake8: ## check style with flake8 + flake8 sarathi + +lint/black: ## check style with black + black --check sarathi + +lint/isort: ## check style with isort + isort --check-only --profile black sarathi + +lint: lint/black lint/isort ## check style + +format/black: ## format code with black + black sarathi + +format/autopep8: ## format code with autopep8 + autopep8 --in-place --recursive sarathi/ + +format/isort: ## format code with isort + isort --profile black sarathi + +format: format/isort format/black ## format code diff --git a/README.md b/README.md index f3943ef..e45f4a5 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # Sarathi-Serve -This repository contains the code for [Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve](https://arxiv.org/abs/2403.02310). -This codebase also serves as the baseline for fidelity tests for the LLM inference system simulator [Vidur](https://github.com/microsoft/vidur). +This is the official OSDI'24 artifact submission for paper #444, "Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve”. 
## Setup @@ -12,7 +11,7 @@ Sarathi-Serve has been tested with CUDA 12.1 on A100 and A40 GPUs. ### Clone repository ```sh -git clone https://github.com/microsoft/sarathi.git +git clone https://msri@dev.azure.com/msri/AI-Infrastructure/_git/llm-batching ``` ### Create mamba environment @@ -30,47 +29,21 @@ Create a Python 3.10 environment, mamba create -p ./env python=3.10 ``` -### Install Dev Dependencies - -```sh -pip install -r requirements-dev.txt -``` - -### Install PyTorch - -```sh -pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121 -``` - -### Update NCCL - -NCCL 2.19 which ships with torch 2.2 by default is buggy, update NCCL to the latest (2.21). -This setup may need to be repeated after the `Install Sarathi-Serve` step is done. - -```sh -pip install -U nvidia-nccl-cu12 -``` - -### Install FlashAttention - -```sh -pip install ninja packaging -git submodule update --init --recursive -cd third_party/flash-attention -python setup.py install -``` - ### Install Sarathi-Serve ```sh pip install -e . --extra-index-url https://flashinfer.ai/whl/cu121/torch2.3/ ``` +## Reproducing Results + +Refer to readmes in individual folders corresponding to each figure in `osdi-experiments`. + ## Citation If you use our work, please consider citing our paper: -```latex +``` @article{agrawal2024taming, title={Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve}, author={Agrawal, Amey and Kedia, Nitin and Panwar, Ashish and Mohan, Jayashree and Kwatra, Nipun and Gulavani, Bhargav S and Tumanov, Alexey and Ramjee, Ramachandran}, @@ -82,27 +55,3 @@ If you use our work, please consider citing our paper: ## Acknowledgment This repository originally started as a fork of the [vLLM project](https://vllm-project.github.io/). Sarathi-Serve is a research prototype and does not have complete feature parity with open-source vLLM. We have only retained the most critical features and adopted the codebase for faster research iterations. - -## Contributing - -Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. - -This project welcomes contributions and suggestions. Most contributions require you to agree to a -Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us -the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. - -When you submit a pull request, a CLA bot will automatically determine whether you need to provide -a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions -provided by the bot. You will only need to do this once across all repos using our CLA. - -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). -For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or -contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. - -## Trademarks - -This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft -trademarks or logos is subject to and must follow -[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). -Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. -Any use of third-party trademarks or logos are subject to those third-party's policies. 
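The Makefile and CONTRIBUTING.md changes above replace the old `format.sh` flow with `make` targets backed by black and isort. As a hedged sketch of the new local workflow — using only targets that appear in the Makefile in this patch, and assuming mamba (or micromamba) is installed as in the README and CI setup — a contributor would run:

```sh
# one-time: create and activate the dev environment declared in environment-dev.yml
mamba env create -f environment-dev.yml
mamba activate sarathi-serve

# auto-format the tree: `make format` runs isort first, then black
make format

# run the same two checks that .github/workflows/lint.yml runs on every push/PR:
#   lint/black -> black --check sarathi
#   lint/isort -> isort --check-only --profile black sarathi
make lint
```

A clean `make lint` locally should therefore match the result of the GitHub Actions lint job, since CI invokes the same `lint/black` and `lint/isort` targets through micromamba.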
diff --git a/config.yml b/config.yml deleted file mode 100644 index abc02f3..0000000 --- a/config.yml +++ /dev/null @@ -1,38 +0,0 @@ -model: meta-llama/Llama-2-7b-hf -replica_id: 0 -replica_resource_mapping: [] -tokenizer: meta-llama/Llama-2-7b-hf -tokenizer_mode: auto -trust_remote_code: true -download_dir: null -load_format: auto -dtype: auto -seed: 0 -pipeline_parallel_size: 1 -tensor_parallel_size: 8 -block_size: 16 -gpu_memory_utilization: 0.85 -revision: null -scheduler_type: sarathi -max_model_len: 4096 -max_num_seqs: 128 -max_num_batched_tokens: null -chunk_size: 100 -enable_dynamic_chunking_schedule: false -low_chunk_size: null -high_chunk_size: null -chunk_schedule_max_tokens: null -chunk_schedule_stages: null -write_metrics: true -output_dir: . -wandb_project: null -wandb_sweep_id: null -wandb_run_id: null -wandb_group: null -wandb_run_name: null -enable_op_level_metrics: false -enable_cpu_op_level_metrics: false -enable_chrome_trace: false -enable_request_outputs: false -keep_individual_batch_metrics: false -attention_backend: FLASHINFER_UNPAGED diff --git a/environment-dev.yml b/environment-dev.yml new file mode 100644 index 0000000..1497fd6 --- /dev/null +++ b/environment-dev.yml @@ -0,0 +1,12 @@ +name: sarathi-serve +channels: + - conda-forge +dependencies: + - python=3.10 + - setuptools + - pip + - make + - black + - isort + - flake8 + - autopep8 diff --git a/format.sh b/format.sh deleted file mode 100755 index 43e7523..0000000 --- a/format.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env bash -# YAPF formatter, adapted from ray and skypilot. -# -# Usage: -# # Do work and commit your work. - -# # Format files that differ from origin/main. -# bash format.sh - -# # Commit changed files with message 'Run yapf and pylint' -# -# -# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. -# You are encouraged to run this locally before pushing changes for review. - -# Cause the script to exit if a single command fails -set -eo pipefail - -# this stops git rev-parse from failing if we run this from the .git directory -builtin cd "$(dirname "${BASH_SOURCE:-$0}")" -ROOT="$(git rev-parse --show-toplevel)" -builtin cd "$ROOT" || exit 1 - -YAPF_VERSION=$(yapf --version | awk '{print $2}') -PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') - -# # params: tool name, tool version, required version -tool_version_check() { - if [[ $2 != $3 ]]; then - echo "Wrong $1 version installed: $3 is required, not $2." - exit 1 - fi -} - -tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)" - -YAPF_FLAGS=( - '--recursive' - '--parallel' -) - -YAPF_EXCLUDES=( - '--exclude' 'build/**' - '--exclude' 'sarathi/model_executor/parallel_utils/**' -) - -# Format specified files -format() { - yapf --in-place "${YAPF_FLAGS[@]}" "$@" -} - -# Format files that differ from main branch. Ignores dirs that are not slated -# for autoformat yet. -format_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause yapf to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that - # exist on both branches. - MERGEBASE="$(git merge-base origin/main HEAD)" - - if ! 
git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ - yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" - fi - -} - -# Format all files -format_all() { - yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sarathi -} - -## This flag formats individual files. --files *must* be the first command line -## arg to use this option. -if [[ "$1" == '--files' ]]; then - format "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is formatted. -elif [[ "$1" == '--all' ]]; then - format_all -else - # Format only the files that changed in last commit. - format_changed -fi -echo 'Sarathi yapf: Done' - -# Run Pylint -echo 'Sarathi Pylint:' -pylint sarathi - -if ! git diff --quiet &>/dev/null; then - echo 'Reformatted files. Please review and stage the changes.' - echo 'Changes not staged for commit:' - echo - git --no-pager diff --name-only - - exit 1 -fi diff --git a/requirements-dev.txt b/requirements-dev.txt index c03b2e4..27c75dd 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,11 +1,6 @@ -# formatting -yapf==0.40.2 -pylint==2.8.2 - -# type checking -types-PyYAML -types-requests -types-setuptools - -# notebooks -nbdime +pylint +black +autopep8 +isort +flake8 +make diff --git a/sarathi/__init__.py b/sarathi/__init__.py index 2b2fc64..76b3fe9 100644 --- a/sarathi/__init__.py +++ b/sarathi/__init__.py @@ -1,9 +1,9 @@ """Sarathi: a high-throughput and memory-efficient inference engine for LLMs""" -from sarathi.engine.arg_utils import EngineArgs -from sarathi.engine.llm_engine import LLMEngine from sarathi.core.datatypes.request_output import RequestOutput from sarathi.core.datatypes.sampling_params import SamplingParams +from sarathi.engine.arg_utils import EngineArgs +from sarathi.engine.llm_engine import LLMEngine __version__ = "0.1.7" diff --git a/sarathi/benchmark/benchmark_runner.py b/sarathi/benchmark/benchmark_runner.py index 783ae14..1508e6d 100644 --- a/sarathi/benchmark/benchmark_runner.py +++ b/sarathi/benchmark/benchmark_runner.py @@ -1,32 +1,33 @@ -import time +import json import logging import os -import json +import time import ray import wandb from tqdm import tqdm +from sarathi import LLMEngine, SamplingParams from sarathi.benchmark.config import Config -from sarathi.benchmark.request_generator import RequestGeneratorRegistry from sarathi.benchmark.entities import Request -from sarathi import LLMEngine -from sarathi import SamplingParams +from sarathi.benchmark.request_generator import RequestGeneratorRegistry +from sarathi.benchmark.types import ReplicaResourceMapping, ResourceMapping +from sarathi.benchmark.utils.random import set_seeds from sarathi.config import MetricsConfig from sarathi.metrics.metrics_store import MetricsStore -from sarathi.benchmark.utils.random import set_seeds from sarathi.utils import get_ip -from sarathi.benchmark.types import ReplicaResourceMapping, ResourceMapping logger = logging.getLogger(__name__) class BenchmarkRunner: - def __init__(self, - replica_id: int, - config: Config, - replica_resource_mapping: ResourceMapping = []) -> None: + def __init__( + self, + replica_id: int, + config: Config, + replica_resource_mapping: ResourceMapping = [], + ) -> None: self._replica_id = replica_id self._config = config self._num_replicas = self._config.cluster_num_replicas @@ -40,14 +41,15 @@ def __init__(self, set_seeds(config.seed) request_generator 
= RequestGeneratorRegistry.get_from_str( - self._config.request_generator_provider, self._config) + self._config.request_generator_provider, self._config + ) self._requests = request_generator.generate() # select every nth request for this replica # e.g. if there are 4 replicas, and this is the 2nd replica, then # we will select the 2nd, 6th, 10th, ... requests # round robin scheduling - self._requests = self._requests[self._replica_id::self._num_replicas] + self._requests = self._requests[self._replica_id :: self._num_replicas] if self._num_replicas == 1: wandb_project = self._config.metrics_store_wandb_project @@ -85,17 +87,13 @@ def __init__(self, max_num_seqs=self._config.replica_scheduler_max_batch_size, # sarathi scheduler config chunk_size=chunk_size, - enable_dynamic_chunking_schedule=self._config. - sarathi_scheduler_enable_dynamic_chunking_schedule, + enable_dynamic_chunking_schedule=self._config.sarathi_scheduler_enable_dynamic_chunking_schedule, low_chunk_size=self._config.sarathi_scheduler_low_chunk_size, high_chunk_size=self._config.sarathi_scheduler_high_chunk_size, - chunk_schedule_max_tokens=self._config. - sarathi_scheduler_chunk_schedule_max_tokens, - chunk_schedule_stages=self._config. - sarathi_scheduler_chunk_schedule_stages, + chunk_schedule_max_tokens=self._config.sarathi_scheduler_chunk_schedule_max_tokens, + chunk_schedule_stages=self._config.sarathi_scheduler_chunk_schedule_stages, # vllm scheduler config - max_num_batched_tokens=self._config. - vllm_scheduler_max_tokens_in_batch, + max_num_batched_tokens=self._config.vllm_scheduler_max_tokens_in_batch, # wandb config write_metrics=self._config.write_metrics, enable_chrome_trace=self._config.write_chrome_trace, @@ -105,20 +103,17 @@ def __init__(self, wandb_sweep_id=self._config.metrics_store_wandb_sweep_id, wandb_run_id=self._config.metrics_store_wandb_run_id, # metrics config - enable_op_level_metrics=self._config. - metrics_store_enable_op_level_metrics, - enable_cpu_op_level_metrics=self._config. - metrics_store_enable_cpu_op_level_metrics, - enable_request_outputs=self._config. - metrics_store_enable_request_outputs, - keep_individual_batch_metrics=self._config. - metrics_store_keep_individual_batch_metrics, + enable_op_level_metrics=self._config.metrics_store_enable_op_level_metrics, + enable_cpu_op_level_metrics=self._config.metrics_store_enable_cpu_op_level_metrics, + enable_request_outputs=self._config.metrics_store_enable_request_outputs, + keep_individual_batch_metrics=self._config.metrics_store_keep_individual_batch_metrics, # engine config trust_remote_code=True, ) - def _get_input_params(self, request: Request, - first_request_time: float) -> SamplingParams: + def _get_input_params( + self, request: Request, first_request_time: float + ) -> SamplingParams: sampling_params = SamplingParams( ignore_eos=True, max_tokens=request.num_decode_tokens, @@ -137,7 +132,8 @@ def _get_input_params(self, request: Request, def warmup(self) -> None: # warmup the engine self._llm_engine.add_request( - **self._get_input_params(self._requests[0], time.monotonic())) + **self._get_input_params(self._requests[0], time.monotonic()) + ) is_completed = False while not is_completed: @@ -152,8 +148,10 @@ def _run(self) -> None: num_processed_requests = 0 num_steps = 0 - pbar = tqdm(total=len(self._requests), - desc=f"Replica {self._replica_id} processed requests") + pbar = tqdm( + total=len(self._requests), + desc=f"Replica {self._replica_id} processed requests", + ) start_time = time.monotonic() # Run the engine. 
@@ -185,7 +183,8 @@ def _add_requests(self) -> None: while index < len(self._requests): request = self._requests[index] self._llm_engine.add_request( - **self._get_input_params(request, first_request_time)) + **self._get_input_params(request, first_request_time) + ) index += 1 def run(self) -> None: @@ -208,13 +207,13 @@ def __init__(self, config: Config) -> None: if self._is_multi_replica: self._validate_cluster_resources() self._runners = self._create_runners() - self._aggregate_metric_store = self._create_aggregate_metric_store( - ) + self._aggregate_metric_store = self._create_aggregate_metric_store() else: replica_resource_mapping = self._get_replica_resource_mapping() assert len(replica_resource_mapping) == 1 - self._runner = BenchmarkRunner(0, self._config, - replica_resource_mapping["0"]) + self._runner = BenchmarkRunner( + 0, self._config, replica_resource_mapping["0"] + ) if wandb.run is not None: wandb.config.update(self._config.__dict__) @@ -227,21 +226,21 @@ def _validate_cluster_resources(self): available_resources = ray.available_resources() - assert available_resources["GPU"] >= num_gpus_required, \ - f"Insufficient GPUs. Required: {num_gpus_required}, Available: {available_resources['GPU']}" + assert ( + available_resources["GPU"] >= num_gpus_required + ), f"Insufficient GPUs. Required: {num_gpus_required}, Available: {available_resources['GPU']}" def _get_replica_resource_mapping(self) -> ReplicaResourceMapping: if self._config.replica_resource_mapping: - replica_resource_mapping = json.loads( - self._config.replica_resource_mapping) - logger.info( - f"Replica resource mapping: {replica_resource_mapping}") + replica_resource_mapping = json.loads(self._config.replica_resource_mapping) + logger.info(f"Replica resource mapping: {replica_resource_mapping}") return replica_resource_mapping cluster_resources_keys = list(ray.available_resources().keys()) num_gpus = ray.available_resources()["GPU"] ip_addresses = [ - x for x in cluster_resources_keys + x + for x in cluster_resources_keys if x.startswith("node:") and x != "node:__internal_head__" ] @@ -253,14 +252,19 @@ def _get_replica_resource_mapping(self) -> ReplicaResourceMapping: num_nodes = len(ip_addresses) assert num_nodes > 0, "No nodes found in the cluster" assert num_gpus > 0, "No GPUs found in the cluster" - assert num_gpus % num_nodes == 0, \ - f"Number of GPUs ({num_gpus}) is not a multiple of number of nodes ({num_nodes})" + assert ( + num_gpus % num_nodes == 0 + ), f"Number of GPUs ({num_gpus}) is not a multiple of number of nodes ({num_nodes})" num_gpus_per_node = int(num_gpus // num_nodes) num_replicas = self._config.cluster_num_replicas - num_gpus_per_replica = self._config.model_tensor_parallel_degree * self._config.model_pipeline_parallel_degree + num_gpus_per_replica = ( + self._config.model_tensor_parallel_degree + * self._config.model_pipeline_parallel_degree + ) - assert num_gpus >= num_replicas * num_gpus_per_replica, \ - f"Insufficient GPUs. Required: {num_replicas * num_gpus_per_replica}, Available: {num_gpus}" + assert ( + num_gpus >= num_replicas * num_gpus_per_replica + ), f"Insufficient GPUs. 
Required: {num_replicas * num_gpus_per_replica}, Available: {num_gpus}" replica_resource_mapping = {} @@ -272,15 +276,17 @@ def _get_replica_resource_mapping(self) -> ReplicaResourceMapping: for replica_id in range(num_replicas): replica_resource_mapping[str(replica_id)] = [] for _ in range(num_gpus_per_replica): - replica_resource_mapping[str(replica_id)].append( - available_gpus.pop(0)) + replica_resource_mapping[str(replica_id)].append(available_gpus.pop(0)) logger.info(f"Replica resource mapping: {replica_resource_mapping}") return replica_resource_mapping def _create_runners(self): - assert self._config.model_tensor_parallel_degree > 1 or self._config.model_pipeline_parallel_degree > 1 + assert ( + self._config.model_tensor_parallel_degree > 1 + or self._config.model_pipeline_parallel_degree > 1 + ) replica_resource_mapping = self._get_replica_resource_mapping() @@ -290,11 +296,14 @@ def _create_runners(self): for replica_id in range(self._config.cluster_num_replicas): runners.append( - runner_class.options(resources={ - replica_resource_mapping[str(replica_id)][0][0]: - 0.01, - }, ).remote(replica_id, self._config, - replica_resource_mapping[str(replica_id)])) + runner_class.options( + resources={ + replica_resource_mapping[str(replica_id)][0][0]: 0.01, + }, + ).remote( + replica_id, self._config, replica_resource_mapping[str(replica_id)] + ) + ) return runners @@ -306,15 +315,11 @@ def _create_aggregate_metric_store(self): wandb_project=self._config.metrics_store_wandb_project, wandb_group=self._config.metrics_store_wandb_group, wandb_run_name=self._config.metrics_store_wandb_run_name, - enable_op_level_metrics=self._config. - metrics_store_enable_op_level_metrics, - enable_cpu_op_level_metrics=self._config. - metrics_store_enable_cpu_op_level_metrics, + enable_op_level_metrics=self._config.metrics_store_enable_op_level_metrics, + enable_cpu_op_level_metrics=self._config.metrics_store_enable_cpu_op_level_metrics, enable_chrome_trace=self._config.write_chrome_trace, - enable_request_outputs=self._config. - metrics_store_enable_request_outputs, - keep_individual_batch_metrics=self._config. 
- metrics_store_keep_individual_batch_metrics, + enable_request_outputs=self._config.metrics_store_enable_request_outputs, + keep_individual_batch_metrics=self._config.metrics_store_keep_individual_batch_metrics, ) metrics_store = MetricsStore(metric_config) metrics_store.mark_initial_memory_profiling_done() @@ -325,8 +330,7 @@ def run(self): if self._is_multi_replica: ray.get([runner.warmup.remote() for runner in self._runners]) - runner_metrics = ray.get( - [runner.run.remote() for runner in self._runners]) + runner_metrics = ray.get([runner.run.remote() for runner in self._runners]) for runner_metric in runner_metrics: self._aggregate_metric_store.merge(runner_metric) diff --git a/sarathi/benchmark/capacity_search/capacity_search.py b/sarathi/benchmark/capacity_search/capacity_search.py index eaca3ba..ef4f5ad 100644 --- a/sarathi/benchmark/capacity_search/capacity_search.py +++ b/sarathi/benchmark/capacity_search/capacity_search.py @@ -1,24 +1,18 @@ import argparse import glob +import json import os import shlex -import json from subprocess import Popen import pandas as pd import ray import wandb -from sarathi.logger import init_logger -from sarathi.benchmark.capacity_search.config import ( - JobConfig, - BenchmarkConfig, -) -from sarathi.benchmark.capacity_search.ray_utils import ( - ResourceManager, - get_ip, -) +from sarathi.benchmark.capacity_search.config import BenchmarkConfig, JobConfig +from sarathi.benchmark.capacity_search.ray_utils import ResourceManager, get_ip from sarathi.benchmark.types import ReplicaResourceMapping +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -54,15 +48,15 @@ def release_resources(self): if not self.resource_mapping: return - ray.get( - self.resource_manager.release_resources.remote( - self.resource_mapping)) + ray.get(self.resource_manager.release_resources.remote(self.resource_mapping)) def _generate_run_command( self, benchmark_config: BenchmarkConfig, ): - resource_mapping_arg = f"--replica_resource_mapping '{json.dumps(self.resource_mapping)}'" + resource_mapping_arg = ( + f"--replica_resource_mapping '{json.dumps(self.resource_mapping)}'" + ) command = f"python -m sarathi.benchmark.main {benchmark_config.to_args()} {resource_mapping_arg}" logger.debug(f"Running command: {command}", flush=True) @@ -82,17 +76,19 @@ def _is_under_sla( benchmark_config: BenchmarkConfig, ) -> tuple[bool, float, float, str]: scheduling_delay_df = pd.read_csv(scheduling_delay_file) - scheduling_delay = scheduling_delay_df[ - "request_scheduling_delay"].quantile( - self.args.scheduling_delay_slo_quantile) + scheduling_delay = scheduling_delay_df["request_scheduling_delay"].quantile( + self.args.scheduling_delay_slo_quantile + ) tbt_df = pd.read_csv(tbt_file) tbt = tbt_df["decode_token_execution_plus_preemption_time"].quantile( - self.args.tbt_slo_quantile) + self.args.tbt_slo_quantile + ) is_under_scheduling_delay_sla = ( scheduling_delay <= self.args.scheduling_delay_slo_value - and tbt <= self.args.tbt_slo_value) + and tbt <= self.args.tbt_slo_value + ) logger.info( f"{benchmark_config.to_human_readable_name()} - " @@ -100,7 +96,11 @@ def _is_under_sla( f" - TBT (P{self.args.tbt_slo_quantile}): {tbt}", flush=True, ) - return is_under_scheduling_delay_sla, scheduling_delay, tbt, benchmark_config.get_run_id( + return ( + is_under_scheduling_delay_sla, + scheduling_delay, + tbt, + benchmark_config.get_run_id(), ) def is_under_sla(self, qps: float) -> tuple[bool, float, float, str]: @@ -117,13 +117,16 @@ def is_under_sla(self, qps: float) -> 
tuple[bool, float, float, str]: os.makedirs(run_dir, exist_ok=True) cached_scheduling_delay_file = self._get_result_file( - run_dir, "request_scheduling_delay") + run_dir, "request_scheduling_delay" + ) cached_tbt_file = self._get_result_file( - run_dir, "decode_token_execution_plus_preemption_time") + run_dir, "decode_token_execution_plus_preemption_time" + ) if cached_scheduling_delay_file is not None and cached_tbt_file is not None: - return self._is_under_sla(cached_scheduling_delay_file, - cached_tbt_file, benchmark_config) + return self._is_under_sla( + cached_scheduling_delay_file, cached_tbt_file, benchmark_config + ) command = self._generate_run_command(benchmark_config) @@ -137,14 +140,15 @@ def is_under_sla(self, qps: float) -> tuple[bool, float, float, str]: p.wait() scheduling_delay_file = self._get_result_file( - run_dir, "request_scheduling_delay") + run_dir, "request_scheduling_delay" + ) tbt_file = self._get_result_file( - run_dir, "decode_token_execution_plus_preemption_time") + run_dir, "decode_token_execution_plus_preemption_time" + ) assert ( scheduling_delay_file is not None and tbt_file is not None ), f"Result file not found for {benchmark_config.to_human_readable_name()}" - return self._is_under_sla(scheduling_delay_file, tbt_file, - benchmark_config) + return self._is_under_sla(scheduling_delay_file, tbt_file, benchmark_config) @release_resources_on_failure def search(self): @@ -171,8 +175,7 @@ def search(self): for _ in range(self.args.max_iterations): logger.info(f"Searching between {left} and {right}", flush=True) # stopping condition - we have reached the minimum granularity - if abs(left - - right) < self.args.min_search_granularity * qps / 100: + if abs(left - right) < self.args.min_search_granularity * qps / 100: break qps = (left + right) / 2 @@ -186,8 +189,7 @@ def search(self): print(f"Searching between {left} and {right} - qps: {qps}", flush=True) - is_under_sla, scheduling_delay, tbt, run_id = self.is_under_sla( - qps) + is_under_sla, scheduling_delay, tbt, run_id = self.is_under_sla(qps) if scheduling_delay is None: break diff --git a/sarathi/benchmark/capacity_search/config/__init__.py b/sarathi/benchmark/capacity_search/config/__init__.py index e552ce8..d0e91b3 100644 --- a/sarathi/benchmark/capacity_search/config/__init__.py +++ b/sarathi/benchmark/capacity_search/config/__init__.py @@ -1,9 +1,9 @@ from sarathi.benchmark.capacity_search.config.config import ( + BenchmarkConfig, JobConfig, ModelConfig, - SchedulerConfig, ParallelConfig, - BenchmarkConfig, + SchedulerConfig, TraceConfig, ) diff --git a/sarathi/benchmark/capacity_search/config/config.py b/sarathi/benchmark/capacity_search/config/config.py index 6741317..35e20f2 100644 --- a/sarathi/benchmark/capacity_search/config/config.py +++ b/sarathi/benchmark/capacity_search/config/config.py @@ -188,24 +188,26 @@ def to_config_dict(self): def generate_job_configs(cls, config: dict): job_configs = [] for ( - model_config, - trace_config, - scheduler_config, - parallel_config, + model_config, + trace_config, + scheduler_config, + parallel_config, ) in product( - config["models"], - config["traces"], - config["schedulers"], - config["parallel_spec"], + config["models"], + config["traces"], + config["schedulers"], + config["parallel_spec"], ): model_config = ModelConfig(**model_config) trace_config = TraceConfig(**trace_config) scheduler_config = SchedulerConfig(**scheduler_config) parallel_config = ParallelConfig(**parallel_config) - if not model_config.is_parallel_spec_valid(parallel_config.name) \ 
- or not model_config.is_scheduler_spec_valid(scheduler_config.name) \ - or not model_config.is_traces_valid(trace_config.name): + if ( + not model_config.is_parallel_spec_valid(parallel_config.name) + or not model_config.is_scheduler_spec_valid(scheduler_config.name) + or not model_config.is_traces_valid(trace_config.name) + ): continue job_config = cls( @@ -276,4 +278,6 @@ def to_human_readable_name(self): return f"{self.job_config.get_human_readable_name()}, QPS: {self.qps}, Run id: {self.get_run_id()}" def get_run_dir(self): - return f"{self.output_dir}/runs/{_get_hash(self.job_config.get_key())}/{self.qps}" + return ( + f"{self.output_dir}/runs/{_get_hash(self.job_config.get_key())}/{self.qps}" + ) diff --git a/sarathi/benchmark/capacity_search/main.py b/sarathi/benchmark/capacity_search/main.py index d1eae99..c2f3c19 100644 --- a/sarathi/benchmark/capacity_search/main.py +++ b/sarathi/benchmark/capacity_search/main.py @@ -11,12 +11,12 @@ import json import os import time -import wandb +import wandb import yaml -from sarathi.logger import init_logger from sarathi.benchmark.capacity_search.search_manager import SearchManager +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -31,22 +31,17 @@ def get_args(): ) parser.add_argument("--output-dir", type=str, required=True) parser.add_argument("--config-path", type=str, required=True) - parser.add_argument("--scheduling-delay-slo-value", - type=float, - default=2.0) - parser.add_argument("--scheduling-delay-slo-quantile", - type=float, - default=0.50) + parser.add_argument("--scheduling-delay-slo-value", type=float, default=2.0) + parser.add_argument("--scheduling-delay-slo-quantile", type=float, default=0.50) parser.add_argument("--tbt-slo-value", type=float, default=0.2) parser.add_argument("--tbt-slo-quantile", type=float, default=0.99) parser.add_argument("--max-iterations", type=int, default=20) - parser.add_argument("--time-limit", - type=int, - default=30, - help="Time limit in minutes") - parser.add_argument("--debug", - action="store_true", - help="Print debug logs and commands") + parser.add_argument( + "--time-limit", type=int, default=30, help="Time limit in minutes" + ) + parser.add_argument( + "--debug", action="store_true", help="Print debug logs and commands" + ) parser.add_argument("--wandb-project", type=str, default=None) parser.add_argument("--wandb-sweep-name", type=str, default=None) parser.add_argument("--wandb-sweep-id", type=str, default=None) @@ -54,7 +49,9 @@ def get_args(): args = parser.parse_args() if args.wandb_project: - assert args.wandb_sweep_name or args.wandb_sweep_id, "wandb-sweep-name/id is required with wandb-project" + assert ( + args.wandb_sweep_name or args.wandb_sweep_id + ), "wandb-sweep-name/id is required with wandb-project" return args @@ -64,9 +61,12 @@ def get_args(): config = yaml.safe_load(open(args.config_path)) - assert (args.scheduling_delay_slo_quantile >= 0 - and args.scheduling_delay_slo_quantile <= 1 - and args.tbt_slo_quantile >= 0 and args.tbt_slo_quantile <= 1) + assert ( + args.scheduling_delay_slo_quantile >= 0 + and args.scheduling_delay_slo_quantile <= 1 + and args.tbt_slo_quantile >= 0 + and args.tbt_slo_quantile <= 1 + ) os.makedirs(args.output_dir, exist_ok=True) diff --git a/sarathi/benchmark/capacity_search/ray_utils.py b/sarathi/benchmark/capacity_search/ray_utils.py index 1d606f5..8a8a293 100644 --- a/sarathi/benchmark/capacity_search/ray_utils.py +++ b/sarathi/benchmark/capacity_search/ray_utils.py @@ -14,7 +14,8 @@ def get_ip() -> str: def 
get_nodes() -> list[str]: cluster_resources_keys = list(ray.available_resources().keys()) ip_addresses = [ - x for x in cluster_resources_keys + x + for x in cluster_resources_keys if x.startswith("node:") and x != "node:__internal_head__" ] return ip_addresses @@ -42,30 +43,34 @@ def __init__(self): assert self._num_nodes > 0, "No nodes found in the cluster" assert self._num_total_gpus > 0, "No GPUs found in the cluster" - assert self._num_total_gpus % self._num_nodes == 0, ( - f"Number of GPUs ({self._num_total_gpus}) is not divisible by the number of nodes ({self._num_nodes})" - ) + assert ( + self._num_total_gpus % self._num_nodes == 0 + ), f"Number of GPUs ({self._num_total_gpus}) is not divisible by the number of nodes ({self._num_nodes})" self._gpus_per_node = int(self._num_total_gpus // self._num_nodes) self._gpu_free_map = { - node: [True] * self._gpus_per_node - for node in self._nodes + node: [True] * self._gpus_per_node for node in self._nodes } self._node_free_map = {node: True for node in self._nodes} def get_replica_resource_mapping( - self, num_gpus: int) -> Optional[ReplicaResourceMapping]: + self, num_gpus: int + ) -> Optional[ReplicaResourceMapping]: """ Assign node and gpu for a job Note that right now we only support single replica mapping """ - assert num_gpus <= self._num_total_gpus, f"Requested {num_gpus} GPUs, but only {self._num_total_gpus} are present in the cluster" + assert ( + num_gpus <= self._num_total_gpus + ), f"Requested {num_gpus} GPUs, but only {self._num_total_gpus} are present in the cluster" is_multi_node = num_gpus > self._gpus_per_node if is_multi_node: - assert num_gpus % self._gpus_per_node == 0, f"Number of GPUs ({num_gpus}) is not divisible by the number of GPUs per node ({self._gpus_per_node})" + assert ( + num_gpus % self._gpus_per_node == 0 + ), f"Number of GPUs ({num_gpus}) is not divisible by the number of GPUs per node ({self._gpus_per_node})" num_nodes = num_gpus // self._gpus_per_node num_free_nodes = sum(self._node_free_map.values()) @@ -106,8 +111,7 @@ def get_replica_resource_mapping( # currently we only support single replica allocation return {} - def release_resources(self, - replica_resource_mapping: ReplicaResourceMapping): + def release_resources(self, replica_resource_mapping: ReplicaResourceMapping): for resource_mapping in replica_resource_mapping.values(): for node, gpu_id in resource_mapping: self._gpu_free_map[node][gpu_id] = True @@ -128,8 +132,9 @@ def map(self, func, job_configs): remote_func = ray.remote(func) - job_configs_with_num_gpus = [(job_config, job_config.get_num_gpus()) - for job_config in job_configs] + job_configs_with_num_gpus = [ + (job_config, job_config.get_num_gpus()) for job_config in job_configs + ] # this reduces fragmentation job_configs_with_num_gpus.sort(key=lambda x: x[1]) @@ -140,16 +145,16 @@ def map(self, func, job_configs): promises = get_ready_promises(promises) replica_resource_mapping = ray.get( - self._resource_manager.get_replica_resource_mapping.remote( - num_gpus)) + self._resource_manager.get_replica_resource_mapping.remote(num_gpus) + ) time.sleep(0.1) # launch the task runner_node = replica_resource_mapping["0"][0][ - 0] # replica 0, first worker, node - promise = remote_func.options(resources={ - runner_node: 0.001 - }).remote(self._resource_manager, replica_resource_mapping, - job_config) + 0 + ] # replica 0, first worker, node + promise = remote_func.options(resources={runner_node: 0.001}).remote( + self._resource_manager, replica_resource_mapping, job_config + ) 
promises.append(promise) return ray.get(promises) diff --git a/sarathi/benchmark/capacity_search/search_manager.py b/sarathi/benchmark/capacity_search/search_manager.py index fd4c480..fc8af5f 100644 --- a/sarathi/benchmark/capacity_search/search_manager.py +++ b/sarathi/benchmark/capacity_search/search_manager.py @@ -2,14 +2,14 @@ import ray -from sarathi.logger import init_logger from sarathi.benchmark.capacity_search.capacity_search import CapacitySearch from sarathi.benchmark.capacity_search.config import JobConfig from sarathi.benchmark.capacity_search.ray_utils import ( - ResourceManager, RayParallelRunner, + ResourceManager, ) from sarathi.benchmark.types import ReplicaResourceMapping +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -49,13 +49,12 @@ def run(self): ray_parallel_runner = RayParallelRunner() - remote_func = ( - lambda resource_manager, resource_mapping, job_config: run_search( - job_config, - self.args, - resource_manager, - resource_mapping, - )) + remote_func = lambda resource_manager, resource_mapping, job_config: run_search( + job_config, + self.args, + resource_manager, + resource_mapping, + ) all_results = ray_parallel_runner.map( remote_func, job_configs, diff --git a/sarathi/benchmark/config/config.py b/sarathi/benchmark/config/config.py index 18498d0..7622c96 100644 --- a/sarathi/benchmark/config/config.py +++ b/sarathi/benchmark/config/config.py @@ -4,19 +4,19 @@ import yaml -from sarathi.logger import init_logger from sarathi.benchmark.constants import DEFAULT_CONFIG_FILE +from sarathi.logger import init_logger logger = init_logger(__name__) def custom_bool(val): - if val.lower() in ('yes', 'true', 't', 'y', '1'): + if val.lower() in ("yes", "true", "t", "y", "1"): return True - elif val.lower() in ('no', 'false', 'f', 'n', '0'): + elif val.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") class Config: @@ -28,7 +28,7 @@ def __getattr__(self, name): return self._args.get(name, None) def __reduce__(self): - return self.__class__, (self._args, ) + return self.__class__, (self._args,) class ConfigParser: @@ -67,22 +67,20 @@ def _update_namespace(self, config_dict, parent_key=""): self._parser.add_argument( f"--{arg_name}", type=custom_bool, - nargs='?', + nargs="?", const=True, default=value, ) elif arg_name in [ - "model_max_model_len", - "vllm_scheduler_max_tokens_in_batch", - "time_limit", + "model_max_model_len", + "vllm_scheduler_max_tokens_in_batch", + "time_limit", ]: - self._parser.add_argument(f"--{arg_name}", - default=value, - type=int) + self._parser.add_argument(f"--{arg_name}", default=value, type=int) else: - self._parser.add_argument(f"--{arg_name}", - default=value, - type=type(value)) + self._parser.add_argument( + f"--{arg_name}", default=value, type=type(value) + ) def get_config(self): return Config(self._args.__dict__) @@ -91,8 +89,7 @@ def get_yaml(self): return yaml.dump(self._args.__dict__, default_flow_style=False) def _write_yaml_to_file(self): - with open(f"{self._args.output_dir}/benchmark_config.yml", - "w") as file: + with open(f"{self._args.output_dir}/benchmark_config.yml", "w") as file: file.write(self.get_yaml()) def to_dict(self): diff --git a/sarathi/benchmark/constants.py b/sarathi/benchmark/constants.py index 80beb64..b18e2ac 100644 --- a/sarathi/benchmark/constants.py +++ b/sarathi/benchmark/constants.py @@ -4,6 +4,6 @@ DEFAULT_CONFIG_FILE = 
f"{ROOT_DIR}/config/default.yml" LOGGER_FORMAT = ( - "[%(asctime)s][%(filename)s:%(lineno)d:%(funcName)s][%(levelname)s] %(message)s" + "[%(asctime)s][%(filename)s:%(lineno)d:%(funcName)s]" "[%(levelname)s] %(message)s" ) LOGGER_TIME_FORMAT = "%H:%M:%S" diff --git a/sarathi/benchmark/main.py b/sarathi/benchmark/main.py index b9b6ade..95a87d8 100644 --- a/sarathi/benchmark/main.py +++ b/sarathi/benchmark/main.py @@ -1,8 +1,8 @@ import logging +from sarathi.benchmark.benchmark_runner import BenchmarkRunnerLauncher from sarathi.benchmark.config import ConfigParser from sarathi.benchmark.constants import LOGGER_FORMAT, LOGGER_TIME_FORMAT -from sarathi.benchmark.benchmark_runner import BenchmarkRunnerLauncher from sarathi.benchmark.utils.random import set_seeds @@ -12,9 +12,9 @@ def main(): set_seeds(config.seed) log_level = getattr(logging, config.log_level.upper()) - logging.basicConfig(format=LOGGER_FORMAT, - level=log_level, - datefmt=LOGGER_TIME_FORMAT) + logging.basicConfig( + format=LOGGER_FORMAT, level=log_level, datefmt=LOGGER_TIME_FORMAT + ) runner = BenchmarkRunnerLauncher(config) runner.run() diff --git a/sarathi/benchmark/request_generator/__init__.py b/sarathi/benchmark/request_generator/__init__.py index 068baaa..1c6c729 100644 --- a/sarathi/benchmark/request_generator/__init__.py +++ b/sarathi/benchmark/request_generator/__init__.py @@ -1,2 +1,3 @@ from sarathi.benchmark.request_generator.request_generator_registry import ( - RequestGeneratorRegistry, ) + RequestGeneratorRegistry, +) diff --git a/sarathi/benchmark/request_generator/fixed_request_length_generator.py b/sarathi/benchmark/request_generator/fixed_request_length_generator.py index 8a7a459..765dd06 100644 --- a/sarathi/benchmark/request_generator/fixed_request_length_generator.py +++ b/sarathi/benchmark/request_generator/fixed_request_length_generator.py @@ -1,11 +1,14 @@ from typing import Tuple from sarathi.benchmark.request_generator.base_request_length_generator import ( - BaseRequestLengthGenerator, ) + BaseRequestLengthGenerator, +) class FixedRequestLengthGenerator(BaseRequestLengthGenerator): def get_next_num_tokens(self) -> Tuple[float, float]: - return self._config.fixed_request_length_generator_prefill_tokens, \ - self._config.fixed_request_length_generator_decode_tokens + return ( + self._config.fixed_request_length_generator_prefill_tokens, + self._config.fixed_request_length_generator_decode_tokens, + ) diff --git a/sarathi/benchmark/request_generator/gamma_request_interval_generator.py b/sarathi/benchmark/request_generator/gamma_request_interval_generator.py index 2a04c08..86925b1 100644 --- a/sarathi/benchmark/request_generator/gamma_request_interval_generator.py +++ b/sarathi/benchmark/request_generator/gamma_request_interval_generator.py @@ -1,7 +1,8 @@ from scipy.stats import gamma from sarathi.benchmark.request_generator.base_request_interval_generator import ( - BaseRequestIntervalGenerator, ) + BaseRequestIntervalGenerator, +) class GammaRequestIntervalGenerator(BaseRequestIntervalGenerator): diff --git a/sarathi/benchmark/request_generator/poisson_request_interval_generator.py b/sarathi/benchmark/request_generator/poisson_request_interval_generator.py index 1e91513..b2f7800 100644 --- a/sarathi/benchmark/request_generator/poisson_request_interval_generator.py +++ b/sarathi/benchmark/request_generator/poisson_request_interval_generator.py @@ -2,7 +2,8 @@ import random from sarathi.benchmark.request_generator.base_request_interval_generator import ( - BaseRequestIntervalGenerator, ) + 
BaseRequestIntervalGenerator, +) class PoissonRequestIntervalGenerator(BaseRequestIntervalGenerator): diff --git a/sarathi/benchmark/request_generator/request_generator_registry.py b/sarathi/benchmark/request_generator/request_generator_registry.py index daa3755..9aa7942 100644 --- a/sarathi/benchmark/request_generator/request_generator_registry.py +++ b/sarathi/benchmark/request_generator/request_generator_registry.py @@ -1,7 +1,9 @@ from sarathi.benchmark.request_generator.synthetic_request_generator import ( - SyntheticRequestGenerator, ) + SyntheticRequestGenerator, +) from sarathi.benchmark.request_generator.trace_replay_request_generator import ( - TraceReplayRequestGenerator, ) + TraceReplayRequestGenerator, +) from sarathi.benchmark.types import RequestGeneratorType from sarathi.utils.base_registry import BaseRegistry @@ -13,7 +15,9 @@ def get_key_from_str(cls, key_str: str) -> RequestGeneratorType: return RequestGeneratorType.from_str(key_str) -RequestGeneratorRegistry.register(RequestGeneratorType.SYNTHETIC, - SyntheticRequestGenerator) -RequestGeneratorRegistry.register(RequestGeneratorType.TRACE_REPLAY, - TraceReplayRequestGenerator) +RequestGeneratorRegistry.register( + RequestGeneratorType.SYNTHETIC, SyntheticRequestGenerator +) +RequestGeneratorRegistry.register( + RequestGeneratorType.TRACE_REPLAY, TraceReplayRequestGenerator +) diff --git a/sarathi/benchmark/request_generator/request_interval_generator_registry.py b/sarathi/benchmark/request_generator/request_interval_generator_registry.py index eb9acd6..b760a9c 100644 --- a/sarathi/benchmark/request_generator/request_interval_generator_registry.py +++ b/sarathi/benchmark/request_generator/request_interval_generator_registry.py @@ -1,11 +1,15 @@ from sarathi.benchmark.request_generator.gamma_request_interval_generator import ( - GammaRequestIntervalGenerator, ) + GammaRequestIntervalGenerator, +) from sarathi.benchmark.request_generator.poisson_request_interval_generator import ( - PoissonRequestIntervalGenerator, ) + PoissonRequestIntervalGenerator, +) from sarathi.benchmark.request_generator.static_request_interval_generator import ( - StaticRequestIntervalGenerator, ) + StaticRequestIntervalGenerator, +) from sarathi.benchmark.request_generator.trace_request_interval_generator import ( - TraceRequestIntervalGenerator, ) + TraceRequestIntervalGenerator, +) from sarathi.benchmark.types import RequestIntervalGeneratorType from sarathi.utils.base_registry import BaseRegistry @@ -17,11 +21,15 @@ def get_key_from_str(cls, key_str: str) -> RequestIntervalGeneratorType: return RequestIntervalGeneratorType.from_str(key_str) -RequestIntervalGeneratorRegistry.register(RequestIntervalGeneratorType.GAMMA, - GammaRequestIntervalGenerator) -RequestIntervalGeneratorRegistry.register(RequestIntervalGeneratorType.POISSON, - PoissonRequestIntervalGenerator) -RequestIntervalGeneratorRegistry.register(RequestIntervalGeneratorType.STATIC, - StaticRequestIntervalGenerator) -RequestIntervalGeneratorRegistry.register(RequestIntervalGeneratorType.TRACE, - TraceRequestIntervalGenerator) +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.GAMMA, GammaRequestIntervalGenerator +) +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.POISSON, PoissonRequestIntervalGenerator +) +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.STATIC, StaticRequestIntervalGenerator +) +RequestIntervalGeneratorRegistry.register( + RequestIntervalGeneratorType.TRACE, TraceRequestIntervalGenerator +) 
diff --git a/sarathi/benchmark/request_generator/request_length_generator_registry.py b/sarathi/benchmark/request_generator/request_length_generator_registry.py index ace3c42..dc05715 100644 --- a/sarathi/benchmark/request_generator/request_length_generator_registry.py +++ b/sarathi/benchmark/request_generator/request_length_generator_registry.py @@ -1,11 +1,15 @@ +from sarathi.benchmark.request_generator.fixed_request_length_generator import ( + FixedRequestLengthGenerator, +) from sarathi.benchmark.request_generator.trace_request_length_generator import ( - TraceRequestLengthGenerator, ) + TraceRequestLengthGenerator, +) from sarathi.benchmark.request_generator.uniform_request_length_generator import ( - UniformRequestLengthGenerator, ) + UniformRequestLengthGenerator, +) from sarathi.benchmark.request_generator.zipf_request_length_generator import ( - ZipfRequestLengthGenerator, ) -from sarathi.benchmark.request_generator.fixed_request_length_generator import ( - FixedRequestLengthGenerator, ) + ZipfRequestLengthGenerator, +) from sarathi.benchmark.types import RequestLengthGeneratorType from sarathi.utils.base_registry import BaseRegistry @@ -17,11 +21,15 @@ def get_key_from_str(cls, key_str: str) -> RequestLengthGeneratorType: return RequestLengthGeneratorType.from_str(key_str) -RequestLengthGeneratorRegistry.register(RequestLengthGeneratorType.ZIPF, - ZipfRequestLengthGenerator) -RequestLengthGeneratorRegistry.register(RequestLengthGeneratorType.UNIFORM, - UniformRequestLengthGenerator) -RequestLengthGeneratorRegistry.register(RequestLengthGeneratorType.TRACE, - TraceRequestLengthGenerator) -RequestLengthGeneratorRegistry.register(RequestLengthGeneratorType.FIXED, - FixedRequestLengthGenerator) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.ZIPF, ZipfRequestLengthGenerator +) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.UNIFORM, UniformRequestLengthGenerator +) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.TRACE, TraceRequestLengthGenerator +) +RequestLengthGeneratorRegistry.register( + RequestLengthGeneratorType.FIXED, FixedRequestLengthGenerator +) diff --git a/sarathi/benchmark/request_generator/static_request_interval_generator.py b/sarathi/benchmark/request_generator/static_request_interval_generator.py index 82a1338..5835fe5 100644 --- a/sarathi/benchmark/request_generator/static_request_interval_generator.py +++ b/sarathi/benchmark/request_generator/static_request_interval_generator.py @@ -1,5 +1,6 @@ from sarathi.benchmark.request_generator.base_request_interval_generator import ( - BaseRequestIntervalGenerator, ) + BaseRequestIntervalGenerator, +) class StaticRequestIntervalGenerator(BaseRequestIntervalGenerator): diff --git a/sarathi/benchmark/request_generator/synthetic_request_generator.py b/sarathi/benchmark/request_generator/synthetic_request_generator.py index 169750b..8c58031 100644 --- a/sarathi/benchmark/request_generator/synthetic_request_generator.py +++ b/sarathi/benchmark/request_generator/synthetic_request_generator.py @@ -1,11 +1,15 @@ from typing import List from sarathi.benchmark.entities import Request -from sarathi.benchmark.request_generator.base_request_generator import BaseRequestGenerator +from sarathi.benchmark.request_generator.base_request_generator import ( + BaseRequestGenerator, +) from sarathi.benchmark.request_generator.request_interval_generator_registry import ( - RequestIntervalGeneratorRegistry, ) + RequestIntervalGeneratorRegistry, +) from 
sarathi.benchmark.request_generator.request_length_generator_registry import ( - RequestLengthGeneratorRegistry, ) + RequestLengthGeneratorRegistry, +) from sarathi.benchmark.utils.random import set_seeds @@ -17,16 +21,18 @@ def __init__(self, *args, **kwargs): self._seed = self._config.seed self._request_length_generator = RequestLengthGeneratorRegistry.get_from_str( - self._config.synthetic_request_generator_length_provider, - self._config) + self._config.synthetic_request_generator_length_provider, self._config + ) self._request_interval_generator = ( RequestIntervalGeneratorRegistry.get_from_str( - self._config.synthetic_request_generator_interval_provider, - self._config)) + self._config.synthetic_request_generator_interval_provider, self._config + ) + ) def _generate_next_request(self, last_arrived_at: float) -> Request: inter_request_time = ( - self._request_interval_generator.get_next_inter_request_time()) + self._request_interval_generator.get_next_inter_request_time() + ) if inter_request_time is None: return None arrived_at = last_arrived_at + inter_request_time @@ -57,8 +63,7 @@ def _generate_requests(self) -> List[Request]: current_time = request.arrived_at requests.append(request) elif self._config.synthetic_request_generator_num_requests is not None: - for _ in range( - self._config.synthetic_request_generator_num_requests): + for _ in range(self._config.synthetic_request_generator_num_requests): request = self._generate_next_request(current_time) current_time = request.arrived_at requests.append(request) @@ -74,10 +79,11 @@ def _generate_requests(self) -> List[Request]: return requests def generate_requests(self) -> List[Request]: - assert (self._config.synthetic_request_generator_num_requests - or self._config.synthetic_request_generator_duration - or self._config.synthetic_request_generator_interval_provider - == "trace") + assert ( + self._config.synthetic_request_generator_num_requests + or self._config.synthetic_request_generator_duration + or self._config.synthetic_request_generator_interval_provider == "trace" + ) set_seeds(self._seed) @@ -88,8 +94,10 @@ def generate_requests(self) -> List[Request]: # remove any requests that arrived after the time limit if self._config.synthetic_request_generator_duration is not None: requests = [ - request for request in requests if request.arrived_at < - self._config.synthetic_request_generator_duration + request + for request in requests + if request.arrived_at + < self._config.synthetic_request_generator_duration ] return requests diff --git a/sarathi/benchmark/request_generator/trace_replay_request_generator.py b/sarathi/benchmark/request_generator/trace_replay_request_generator.py index b445c6a..72fa1c4 100644 --- a/sarathi/benchmark/request_generator/trace_replay_request_generator.py +++ b/sarathi/benchmark/request_generator/trace_replay_request_generator.py @@ -4,7 +4,9 @@ import pandas as pd from sarathi.benchmark.entities import Request -from sarathi.benchmark.request_generator.base_request_generator import BaseRequestGenerator +from sarathi.benchmark.request_generator.base_request_generator import ( + BaseRequestGenerator, +) logger = logging.getLogger(__name__) @@ -22,49 +24,61 @@ def __init__(self, *args, **kwargs): # load into a pd dataframe self._trace_df = pd.read_csv(self._trace_file) # restrict trace_df to be a subset of rows that have the same date - self._trace_df = self._trace_df[self._trace_df["Date"] == self._config. 
- trace_request_generator_date] + self._trace_df = self._trace_df[ + self._trace_df["Date"] == self._config.trace_request_generator_date + ] # scale prefill and decode tokens self._trace_df["PromptTokenCount"] = ( - self._trace_df["PromptTokenCount"] * - self._config.trace_request_generator_prefill_scale_factor) + self._trace_df["PromptTokenCount"] + * self._config.trace_request_generator_prefill_scale_factor + ) self._trace_df["CompletionTokenCount"] = ( - self._trace_df["CompletionTokenCount"] * - self._config.trace_request_generator_decode_scale_factor) + self._trace_df["CompletionTokenCount"] + * self._config.trace_request_generator_decode_scale_factor + ) # make sure all the prefill and decode counts are integers - self._trace_df["PromptTokenCount"] = self._trace_df[ - "PromptTokenCount"].astype(int) + self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].astype( + int + ) self._trace_df["CompletionTokenCount"] = self._trace_df[ - "CompletionTokenCount"].astype(int) + "CompletionTokenCount" + ].astype(int) # make sure that there is at least one prefill and decode token - self._trace_df["PromptTokenCount"] = self._trace_df[ - "PromptTokenCount"].clip(lower=1) + self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].clip( + lower=1 + ) self._trace_df["CompletionTokenCount"] = self._trace_df[ - "CompletionTokenCount"].clip(lower=1) + "CompletionTokenCount" + ].clip(lower=1) # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed - total_tokens = (self._trace_df["PromptTokenCount"] + - self._trace_df["CompletionTokenCount"]) + total_tokens = ( + self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] + ) diff_tokens = total_tokens - self._config.trace_request_generator_max_tokens diff_tokens = diff_tokens.clip(lower=0) self._trace_df["PromptTokenCount"] = ( - self._trace_df["PromptTokenCount"] - diff_tokens) + self._trace_df["PromptTokenCount"] - diff_tokens + ) - assert all(self._trace_df["PromptTokenCount"] + - self._trace_df["CompletionTokenCount"] <= - self._config.trace_request_generator_max_tokens) + assert all( + self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] + <= self._config.trace_request_generator_max_tokens + ) # rescale the time to change QPS self._trace_df["Time"] = ( - self._trace_df["Time"] * - self._config.trace_request_generator_time_scale_factor) + self._trace_df["Time"] + * self._config.trace_request_generator_time_scale_factor + ) # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles - pd_ratio = (self._trace_df["PromptTokenCount"] / - self._trace_df["CompletionTokenCount"]) + pd_ratio = ( + self._trace_df["PromptTokenCount"] / self._trace_df["CompletionTokenCount"] + ) logger.info( f"Loaded trace file {self._trace_file} with {len(self._trace_df)} requests" ) diff --git a/sarathi/benchmark/request_generator/trace_request_interval_generator.py b/sarathi/benchmark/request_generator/trace_request_interval_generator.py index 1b781ff..5f5a1a9 100644 --- a/sarathi/benchmark/request_generator/trace_request_interval_generator.py +++ b/sarathi/benchmark/request_generator/trace_request_interval_generator.py @@ -3,7 +3,8 @@ import pandas as pd from sarathi.benchmark.request_generator.base_request_interval_generator import ( - BaseRequestIntervalGenerator, ) + BaseRequestIntervalGenerator, +) logger = logging.getLogger(__name__) @@ -21,28 +22,32 @@ def __init__(self, *args, **kwargs): # load into a pd dataframe self._trace_df = pd.read_csv(trace_file) - 
self._trace_df["arrival_time"] = pd.to_datetime( - self._trace_df["arrival_time"]) + self._trace_df["arrival_time"] = pd.to_datetime(self._trace_df["arrival_time"]) # restrict trace_df to be a subset of rows that have the same date self._trace_df = self._trace_df[ - (self._trace_df["arrival_time"] > - self._config.trace_request_interval_generator_start_time) - & (self._trace_df["arrival_time"] < - self._config.trace_request_interval_generator_end_time)] + ( + self._trace_df["arrival_time"] + > self._config.trace_request_interval_generator_start_time + ) + & ( + self._trace_df["arrival_time"] + < self._config.trace_request_interval_generator_end_time + ) + ] # change back to seconds self._trace_df["arrival_time"] = ( - self._trace_df["arrival_time"] - - self._trace_df["arrival_time"].min()) // pd.Timedelta("1s") + self._trace_df["arrival_time"] - self._trace_df["arrival_time"].min() + ) // pd.Timedelta("1s") # rescale the time to change QPS self._trace_df["arrival_time"] = ( - self._trace_df["arrival_time"] * - self._config.trace_request_interval_generator_time_scale_factor) + self._trace_df["arrival_time"] + * self._config.trace_request_interval_generator_time_scale_factor + ) # compute the inter-request time - self._trace_df["inter_request_time"] = self._trace_df[ - "arrival_time"].diff() + self._trace_df["inter_request_time"] = self._trace_df["arrival_time"].diff() self._next_request_idx = 1 @@ -54,7 +59,8 @@ def get_next_inter_request_time(self) -> float: if self._next_request_idx >= len(self._trace_df): return None - inter_request_time = self._trace_df.iloc[ - self._next_request_idx]["inter_request_time"] + inter_request_time = self._trace_df.iloc[self._next_request_idx][ + "inter_request_time" + ] self._next_request_idx += 1 return inter_request_time diff --git a/sarathi/benchmark/request_generator/trace_request_length_generator.py b/sarathi/benchmark/request_generator/trace_request_length_generator.py index dfb1835..b316c22 100644 --- a/sarathi/benchmark/request_generator/trace_request_length_generator.py +++ b/sarathi/benchmark/request_generator/trace_request_length_generator.py @@ -5,7 +5,8 @@ import pandas as pd from sarathi.benchmark.request_generator.base_request_length_generator import ( - BaseRequestLengthGenerator, ) + BaseRequestLengthGenerator, +) logger = logging.getLogger(__name__) @@ -20,52 +21,64 @@ def __init__(self, *args, **kwargs): # scale prefill and decode tokens self._trace_df["num_prefill_tokens"] = ( - self._trace_df["num_prefill_tokens"] * - self._config.trace_request_length_generator_prefill_scale_factor) + self._trace_df["num_prefill_tokens"] + * self._config.trace_request_length_generator_prefill_scale_factor + ) self._trace_df["num_decode_tokens"] = ( - self._trace_df["num_decode_tokens"] * - self._config.trace_request_length_generator_decode_scale_factor) + self._trace_df["num_decode_tokens"] + * self._config.trace_request_length_generator_decode_scale_factor + ) # make sure all the prefill and decode counts are integers self._trace_df["num_prefill_tokens"] = self._trace_df[ - "num_prefill_tokens"].astype(int) + "num_prefill_tokens" + ].astype(int) self._trace_df["num_decode_tokens"] = self._trace_df[ - "num_decode_tokens"].astype(int) + "num_decode_tokens" + ].astype(int) # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed - total_tokens = (self._trace_df["num_prefill_tokens"] + - self._trace_df["num_decode_tokens"]) - diff_tokens = total_tokens - self._config.trace_request_length_generator_max_tokens + total_tokens 
= ( + self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] + ) + diff_tokens = ( + total_tokens - self._config.trace_request_length_generator_max_tokens + ) diff_tokens = diff_tokens.clip(lower=0) # dedcut the diff tokens from the prefill and decode tokens proportionally - prefill_tokens_ratio = self._trace_df[ - "num_prefill_tokens"] / total_tokens + prefill_tokens_ratio = self._trace_df["num_prefill_tokens"] / total_tokens decode_tokens_ratio = self._trace_df["num_decode_tokens"] / total_tokens - self._trace_df["num_prefill_tokens"] -= (np.ceil( - diff_tokens * prefill_tokens_ratio)).astype(int) + self._trace_df["num_prefill_tokens"] -= ( + np.ceil(diff_tokens * prefill_tokens_ratio) + ).astype(int) - self._trace_df["num_decode_tokens"] -= (np.ceil( - diff_tokens * decode_tokens_ratio)).astype(int) + self._trace_df["num_decode_tokens"] -= ( + np.ceil(diff_tokens * decode_tokens_ratio) + ).astype(int) # make sure that there is at least one prefill and decode token self._trace_df["num_prefill_tokens"] = self._trace_df[ - "num_prefill_tokens"].clip(lower=1) - self._trace_df["num_decode_tokens"] = self._trace_df[ - "num_decode_tokens"].clip(lower=1) + "num_prefill_tokens" + ].clip(lower=1) + self._trace_df["num_decode_tokens"] = self._trace_df["num_decode_tokens"].clip( + lower=1 + ) - assert all(self._trace_df["num_prefill_tokens"] + - self._trace_df["num_decode_tokens"] <= - self._config.trace_request_length_generator_max_tokens) + assert all( + self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] + <= self._config.trace_request_length_generator_max_tokens + ) assert all(self._trace_df["num_prefill_tokens"] > 0) assert all(self._trace_df["num_decode_tokens"] > 0) # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles - pd_ratio = (self._trace_df["num_prefill_tokens"] / - self._trace_df["num_decode_tokens"]) + pd_ratio = ( + self._trace_df["num_prefill_tokens"] / self._trace_df["num_decode_tokens"] + ) logger.info( f"Loaded request length trace file {trace_file} with {len(self._trace_df)} requests" ) @@ -74,8 +87,7 @@ def __init__(self, *args, **kwargs): ) # randomly shuffle the df based on the seed - self._trace_df = self._trace_df.sample(frac=1, - random_state=self._config.seed) + self._trace_df = self._trace_df.sample(frac=1, random_state=self._config.seed) self._next_request_idx = 0 def get_next_num_tokens(self) -> Tuple[float, float]: diff --git a/sarathi/benchmark/request_generator/uniform_request_length_generator.py b/sarathi/benchmark/request_generator/uniform_request_length_generator.py index 5782f3c..692125f 100644 --- a/sarathi/benchmark/request_generator/uniform_request_length_generator.py +++ b/sarathi/benchmark/request_generator/uniform_request_length_generator.py @@ -1,9 +1,10 @@ -import random import math +import random from typing import Tuple from sarathi.benchmark.request_generator.base_request_length_generator import ( - BaseRequestLengthGenerator, ) + BaseRequestLengthGenerator, +) class UniformRequestLengthGenerator(BaseRequestLengthGenerator): @@ -15,9 +16,12 @@ def get_next_num_tokens(self) -> Tuple[float, float]: ) decode_tokens = math.ceil( - total_tokens / - (1 + self._config. 
- uniform_request_length_generator_prefill_to_decode_ratio)) + total_tokens + / ( + 1 + + self._config.uniform_request_length_generator_prefill_to_decode_ratio + ) + ) prefill_tokens = total_tokens - decode_tokens assert prefill_tokens > 0 and decode_tokens > 0 diff --git a/sarathi/benchmark/request_generator/zipf_request_length_generator.py b/sarathi/benchmark/request_generator/zipf_request_length_generator.py index 3a96efe..80c0c38 100644 --- a/sarathi/benchmark/request_generator/zipf_request_length_generator.py +++ b/sarathi/benchmark/request_generator/zipf_request_length_generator.py @@ -1,7 +1,8 @@ from typing import Tuple from sarathi.benchmark.request_generator.base_request_length_generator import ( - BaseRequestLengthGenerator, ) + BaseRequestLengthGenerator, +) from sarathi.benchmark.utils.zipf_generator import ZipfGenerator @@ -22,7 +23,8 @@ def get_next_num_tokens(self) -> Tuple[float, float]: total_tokens = self._zipf_generator.next() decode_tokens = total_tokens / ( - 1 + self._config.zipf_request_generator_prefill_to_decode_ratio) + 1 + self._config.zipf_request_generator_prefill_to_decode_ratio + ) prefill_tokens = total_tokens - decode_tokens return prefill_tokens, decode_tokens diff --git a/sarathi/benchmark/types/__init__.py b/sarathi/benchmark/types/__init__.py index 9859f9c..df3ec42 100644 --- a/sarathi/benchmark/types/__init__.py +++ b/sarathi/benchmark/types/__init__.py @@ -1,13 +1,18 @@ -from typing import List, Tuple, Dict +from typing import Dict, List, Tuple -from sarathi.utils.base_int_enum import BaseIntEnum from sarathi.benchmark.types.request_generator_type import RequestGeneratorType -from sarathi.benchmark.types.request_interval_generator_type import RequestIntervalGeneratorType -from sarathi.benchmark.types.request_length_generator_type import RequestLengthGeneratorType +from sarathi.benchmark.types.request_interval_generator_type import ( + RequestIntervalGeneratorType, +) +from sarathi.benchmark.types.request_length_generator_type import ( + RequestLengthGeneratorType, +) +from sarathi.utils.base_int_enum import BaseIntEnum ResourceMapping = List[Tuple[str, int]] # List of (node_ip, gpu_id) ReplicaResourceMapping = Dict[ - str, ResourceMapping] # Dict of replica_id -> ResourceMapping + str, ResourceMapping +] # Dict of replica_id -> ResourceMapping __all__ = [ RequestGeneratorType, diff --git a/sarathi/benchmark/types/request_length_generator_type.py b/sarathi/benchmark/types/request_length_generator_type.py index d37249b..b37d42e 100644 --- a/sarathi/benchmark/types/request_length_generator_type.py +++ b/sarathi/benchmark/types/request_length_generator_type.py @@ -5,4 +5,4 @@ class RequestLengthGeneratorType(BaseIntEnum): UNIFORM = 1 ZIPF = 2 TRACE = 3 - FIXED = 4 \ No newline at end of file + FIXED = 4 diff --git a/sarathi/benchmark/utils/zipf_generator.py b/sarathi/benchmark/utils/zipf_generator.py index 49974bd..a289964 100644 --- a/sarathi/benchmark/utils/zipf_generator.py +++ b/sarathi/benchmark/utils/zipf_generator.py @@ -5,8 +5,9 @@ class ZipfGenerator: - def __init__(self, min: int, max: int, theta: float, scramble: bool, - seed: int) -> None: + def __init__( + self, min: int, max: int, theta: float, scramble: bool, seed: int + ) -> None: self._min = min self._max = max self._items = max - min + 1 @@ -15,7 +16,8 @@ def __init__(self, min: int, max: int, theta: float, scramble: bool, self._alpha = 1.0 / (1.0 - self._theta) self._zetan = self._zeta(self._items, self._theta) self._eta = (1 - np.power(2.0 / self._items, 1 - self._theta)) / ( - 1 - 
self._zeta_2 / (self._zetan + EPS)) + 1 - self._zeta_2 / (self._zetan + EPS) + ) self._scramble = scramble self._seed = seed self._generator = np.random.RandomState(seed) @@ -34,13 +36,12 @@ def _next(self) -> int: return self._min + 1 return self._min + int( - (self._items) * - np.power(self._eta * u - self._eta + 1, self._alpha)) + (self._items) * np.power(self._eta * u - self._eta + 1, self._alpha) + ) def next(self) -> int: retval = self._next() if self._scramble: - retval = self._min + hash(str(retval) + - str(self._seed)) % self._items + retval = self._min + hash(str(retval) + str(self._seed)) % self._items return retval diff --git a/sarathi/config.py b/sarathi/config.py index 22c1da9..afd28c6 100644 --- a/sarathi/config.py +++ b/sarathi/config.py @@ -1,5 +1,5 @@ -from typing import Optional, List, Tuple from abc import ABC +from typing import List, Optional, Tuple import torch from transformers import PretrainedConfig @@ -49,7 +49,7 @@ class ModelConfig: a tag name, or a commit id. If unspecified, will use the default version. max_model_len: Maximum length of a sequence (including prompt and - output). If None, will be derived from the model. + output). If None, will be derived from the model. """ def __init__( @@ -85,19 +85,17 @@ def __init__( self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.hf_config.dtype = self.dtype - self.max_model_len = _get_and_verify_max_len(self.hf_config, - max_model_len) + self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) self._verify_load_format() self._verify_tokenizer_mode() def _verify_load_format(self) -> None: load_format = self.load_format.lower() - if load_format not in [ - "auto", "pt", "safetensors", "npcache", "dummy" - ]: + if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy"]: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'." + ) self.load_format = load_format def _verify_tokenizer_mode(self) -> None: @@ -105,7 +103,8 @@ def _verify_tokenizer_mode(self) -> None: if tokenizer_mode not in ["auto", "slow"]: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto' or 'slow'.") + "either 'auto' or 'slow'." + ) self.tokenizer_mode = tokenizer_mode def verify_with_parallel_config( @@ -118,7 +117,8 @@ def verify_with_parallel_config( raise ValueError( f"Total number of attention heads ({total_num_attention_heads})" " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") + f"({tensor_parallel_size})." + ) total_num_hidden_layers = self.hf_config.num_hidden_layers pipeline_parallel_size = parallel_config.pipeline_parallel_size @@ -126,7 +126,8 @@ def verify_with_parallel_config( raise ValueError( f"Total number of hidden layers ({total_num_hidden_layers}) " "must be divisible by pipeline parallel size " - f"({pipeline_parallel_size}).") + f"({pipeline_parallel_size})." 
+ ) def get_hidden_size(self) -> int: return self.hf_config.hidden_size @@ -143,32 +144,35 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_config, - "multi_query", False): + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_config, "multi_query", False + ): # Multi-query attention, only one KV head. return 1 # For Falcon: if getattr(self.hf_config, "n_head_kv", None) is not None: - return (self.hf_config.n_head_kv // - parallel_config.tensor_parallel_size) + return self.hf_config.n_head_kv // parallel_config.tensor_parallel_size # For Falcon-40b/Falcon-180b: if getattr(self.hf_config, "num_kv_heads", None) is not None: - return (self.hf_config.num_kv_heads // - parallel_config.tensor_parallel_size) + return self.hf_config.num_kv_heads // parallel_config.tensor_parallel_size # For LLaMA-2: if getattr(self.hf_config, "num_key_value_heads", None) is not None: - return (self.hf_config.num_key_value_heads // - parallel_config.tensor_parallel_size) + return ( + self.hf_config.num_key_value_heads + // parallel_config.tensor_parallel_size + ) total_num_attention_heads = self.hf_config.num_attention_heads return total_num_attention_heads // parallel_config.tensor_parallel_size def get_num_q_heads(self, parallel_config: "ParallelConfig") -> int: if getattr(self.hf_config, "num_attention_heads", None) is not None: - return (self.hf_config.num_attention_heads // - parallel_config.tensor_parallel_size) - raise ValueError( - "num_attention_heads is not defined in the model config") + return ( + self.hf_config.num_attention_heads + // parallel_config.tensor_parallel_size + ) + raise ValueError("num_attention_heads is not defined in the model config") def get_max_model_len(self) -> int: return self.max_model_len @@ -206,7 +210,8 @@ def _verify_args(self) -> None: if self.gpu_memory_utilization > 1.0: raise ValueError( "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") + f"{self.gpu_memory_utilization}." + ) class ParallelConfig: @@ -228,8 +233,7 @@ def __init__( if not replica_resource_mapping: replica_resource_mapping = [ - (None, i) - for i in range(pipeline_parallel_size * tensor_parallel_size) + (None, i) for i in range(pipeline_parallel_size * tensor_parallel_size) ] self.replica_resource_mapping = replica_resource_mapping @@ -280,13 +284,17 @@ class VLLMSchedulerConfig(BaseSchedulerConfig): moving from WAITING to RUNNING states. """ - def __init__(self, max_num_seqs: int, max_model_len: int, - num_pipeline_stages: int, - max_num_batched_tokens: int) -> None: + def __init__( + self, + max_num_seqs: int, + max_model_len: int, + num_pipeline_stages: int, + max_num_batched_tokens: int, + ) -> None: super().__init__(max_num_seqs, max_model_len, num_pipeline_stages) - self._max_num_batched_tokens = (max_num_batched_tokens - if max_num_batched_tokens else - max_model_len) + self._max_num_batched_tokens = ( + max_num_batched_tokens if max_num_batched_tokens else max_model_len + ) # Requests with context length upto max_model_len must be schedulable. 
assert max_model_len <= self._max_num_batched_tokens @@ -380,14 +388,23 @@ def type(self): class MetricsConfig: """Metric configuration.""" - def __init__(self, replica_id: int, write_metrics: bool, output_dir: str, - wandb_project: str, wandb_group: str, wandb_run_name: str, - wandb_sweep_id: str, wandb_run_id: str, - enable_op_level_metrics: bool, - enable_cpu_op_level_metrics: bool, enable_chrome_trace: bool, - enable_request_outputs: bool, - keep_individual_batch_metrics: bool, - model_num_layers: int) -> None: + def __init__( + self, + replica_id: int, + write_metrics: bool, + output_dir: str, + wandb_project: str, + wandb_group: str, + wandb_run_name: str, + wandb_sweep_id: str, + wandb_run_id: str, + enable_op_level_metrics: bool, + enable_cpu_op_level_metrics: bool, + enable_chrome_trace: bool, + enable_request_outputs: bool, + keep_individual_batch_metrics: bool, + model_num_layers: int, + ) -> None: self.replica_id = replica_id self.write_metrics = write_metrics self.output_dir = output_dir @@ -416,7 +433,8 @@ def __str__(self) -> str: f"enable_chrome_trace={self.enable_chrome_trace}, " f"enable_request_outputs={self.enable_request_outputs}, " f"keep_individual_batch_metrics=" - f"{self.keep_individual_batch_metrics})") + f"{self.keep_individual_batch_metrics})" + ) _STR_DTYPE_TO_TORCH_DTYPE = { @@ -470,7 +488,8 @@ def _get_and_verify_dtype( raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}.") + f"{compute_capability[0]}.{compute_capability[1]}." + ) return torch_dtype @@ -503,21 +522,22 @@ def _get_and_verify_max_len( raise ValueError( "When using rope_scaling, the model's config.json must " "contain one of the following keys to determine the original " - f"maximum length of the model: {possible_keys}") + f"maximum length of the model: {possible_keys}" + ) assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] + derived_max_model_len = rope_scaling["original_max_position_embeddings"] derived_max_model_len *= scaling_factor if max_model_len is None: - logger.info( - f"Using the derived maximum model length: {derived_max_model_len}") + logger.info(f"Using the derived maximum model length: {derived_max_model_len}") max_model_len = derived_max_model_len elif max_model_len > derived_max_model_len: - logger.info(f"Applying rope_scaling to the maximum model length: " - f"{derived_max_model_len} -> {max_model_len}") + logger.info( + f"Applying rope_scaling to the maximum model length: " + f"{derived_max_model_len} -> {max_model_len}" + ) # force rope_scaling scaling_factor = max_model_len / derived_max_model_len rope_scaling = {"type": "linear", "factor": scaling_factor} diff --git a/sarathi/core/block_space_manager/base_block_space_manager.py b/sarathi/core/block_space_manager/base_block_space_manager.py index c8bbefa..f51f6c4 100644 --- a/sarathi/core/block_space_manager/base_block_space_manager.py +++ b/sarathi/core/block_space_manager/base_block_space_manager.py @@ -1,6 +1,7 @@ """A block manager that manages token blocks.""" -from typing import Dict, List + from abc import ABC, abstractmethod +from typing import Dict, List from sarathi.core.datatypes.block import PhysicalTokenBlock from sarathi.core.datatypes.sequence import Sequence diff --git a/sarathi/core/block_space_manager/block_space_manager_registry.py 
b/sarathi/core/block_space_manager/block_space_manager_registry.py index 22f7e8d..fb10cd4 100644 --- a/sarathi/core/block_space_manager/block_space_manager_registry.py +++ b/sarathi/core/block_space_manager/block_space_manager_registry.py @@ -1,9 +1,19 @@ -from sarathi.core.block_space_manager.vllm_block_space_manager import VLLMBlockSpaceManager -from sarathi.core.block_space_manager.orca_block_space_manager import OrcaBlockSpaceManager -from sarathi.core.block_space_manager.faster_transformer_block_space_manager import FasterTransformerBlockSpaceManager -from sarathi.core.block_space_manager.simple_chunking_block_space_manager import SimpleChunkingBlockSpaceManager -from sarathi.core.block_space_manager.sarathi_block_space_manager import SarathiBlockSpaceManager from sarathi.config import SchedulerType +from sarathi.core.block_space_manager.faster_transformer_block_space_manager import ( + FasterTransformerBlockSpaceManager, +) +from sarathi.core.block_space_manager.orca_block_space_manager import ( + OrcaBlockSpaceManager, +) +from sarathi.core.block_space_manager.sarathi_block_space_manager import ( + SarathiBlockSpaceManager, +) +from sarathi.core.block_space_manager.simple_chunking_block_space_manager import ( + SimpleChunkingBlockSpaceManager, +) +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) from sarathi.utils.base_registry import BaseRegistry @@ -16,9 +26,10 @@ def get_key_from_str(cls, key_str: str) -> SchedulerType: BlockSpaceManagerRegistry.register(SchedulerType.VLLM, VLLMBlockSpaceManager) BlockSpaceManagerRegistry.register(SchedulerType.ORCA, OrcaBlockSpaceManager) -BlockSpaceManagerRegistry.register(SchedulerType.FASTER_TRANSFORMER, - FasterTransformerBlockSpaceManager) -BlockSpaceManagerRegistry.register(SchedulerType.SARATHI, - SarathiBlockSpaceManager) -BlockSpaceManagerRegistry.register(SchedulerType.SIMPLE_CHUNKING, - SimpleChunkingBlockSpaceManager) +BlockSpaceManagerRegistry.register( + SchedulerType.FASTER_TRANSFORMER, FasterTransformerBlockSpaceManager +) +BlockSpaceManagerRegistry.register(SchedulerType.SARATHI, SarathiBlockSpaceManager) +BlockSpaceManagerRegistry.register( + SchedulerType.SIMPLE_CHUNKING, SimpleChunkingBlockSpaceManager +) diff --git a/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py b/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py index c096473..53f86c4 100644 --- a/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py +++ b/sarathi/core/block_space_manager/faster_transformer_block_space_manager.py @@ -1,4 +1,6 @@ -from sarathi.core.block_space_manager.orca_block_space_manager import OrcaBlockSpaceManager +from sarathi.core.block_space_manager.orca_block_space_manager import ( + OrcaBlockSpaceManager, +) class FasterTransformerBlockSpaceManager(OrcaBlockSpaceManager): diff --git a/sarathi/core/block_space_manager/orca_block_space_manager.py b/sarathi/core/block_space_manager/orca_block_space_manager.py index 69fb8e9..b471e1d 100644 --- a/sarathi/core/block_space_manager/orca_block_space_manager.py +++ b/sarathi/core/block_space_manager/orca_block_space_manager.py @@ -1,7 +1,9 @@ from math import ceil +from sarathi.core.block_space_manager.base_block_space_manager import ( + BaseBlockSpaceManager, +) from sarathi.core.datatypes.sequence import Sequence -from sarathi.core.block_space_manager.base_block_space_manager import BaseBlockSpaceManager class OrcaBlockSpaceManager(BaseBlockSpaceManager): diff --git 
a/sarathi/core/block_space_manager/sarathi_block_space_manager.py b/sarathi/core/block_space_manager/sarathi_block_space_manager.py index dac72a1..b3e65b3 100644 --- a/sarathi/core/block_space_manager/sarathi_block_space_manager.py +++ b/sarathi/core/block_space_manager/sarathi_block_space_manager.py @@ -1,4 +1,6 @@ -from sarathi.core.block_space_manager.vllm_block_space_manager import VLLMBlockSpaceManager +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) class SarathiBlockSpaceManager(VLLMBlockSpaceManager): diff --git a/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py b/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py index 4bc0e34..e7b64f9 100644 --- a/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py +++ b/sarathi/core/block_space_manager/simple_chunking_block_space_manager.py @@ -1,4 +1,6 @@ -from sarathi.core.block_space_manager.vllm_block_space_manager import VLLMBlockSpaceManager +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) class SimpleChunkingBlockSpaceManager(VLLMBlockSpaceManager): diff --git a/sarathi/core/block_space_manager/vllm_block_space_manager.py b/sarathi/core/block_space_manager/vllm_block_space_manager.py index 3cc3650..2e90fe8 100644 --- a/sarathi/core/block_space_manager/vllm_block_space_manager.py +++ b/sarathi/core/block_space_manager/vllm_block_space_manager.py @@ -1,5 +1,7 @@ +from sarathi.core.block_space_manager.base_block_space_manager import ( + BaseBlockSpaceManager, +) from sarathi.core.datatypes.sequence import Sequence -from sarathi.core.block_space_manager.base_block_space_manager import BaseBlockSpaceManager class VLLMBlockSpaceManager(BaseBlockSpaceManager): diff --git a/sarathi/core/datatypes/block.py b/sarathi/core/datatypes/block.py index 392ba20..6baa1b4 100644 --- a/sarathi/core/datatypes/block.py +++ b/sarathi/core/datatypes/block.py @@ -1,4 +1,5 @@ """Token blocks.""" + from typing import List _BLANK_TOKEN_ID = -1 @@ -34,11 +35,11 @@ def is_full(self) -> bool: def append_tokens(self, token_ids: List[int]) -> None: assert len(token_ids) <= self.get_num_empty_slots() curr_idx = self.num_tokens - self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids + self.token_ids[curr_idx : curr_idx + len(token_ids)] = token_ids self.num_tokens += len(token_ids) def get_token_ids(self) -> List[int]: - return self.token_ids[:self.num_tokens] + return self.token_ids[: self.num_tokens] def get_last_token_id(self) -> int: assert self.num_tokens > 0 @@ -57,5 +58,7 @@ def __init__( self.block_size = block_size def __repr__(self) -> str: - return (f'PhysicalTokenBlock(device={self.device}, ' - f'block_number={self.block_number})') + return ( + f"PhysicalTokenBlock(device={self.device}, " + f"block_number={self.block_number})" + ) diff --git a/sarathi/core/datatypes/request_output.py b/sarathi/core/datatypes/request_output.py index 5b086e2..a2862ed 100644 --- a/sarathi/core/datatypes/request_output.py +++ b/sarathi/core/datatypes/request_output.py @@ -16,6 +16,7 @@ class RequestOutput: outputs: The output sequences of the request. finished: Whether the whole request is finished. 
""" + seq_id: str prompt: str prompt_token_ids: List[int] @@ -26,7 +27,12 @@ class RequestOutput: @classmethod def from_seq(cls, seq: Sequence) -> "RequestOutput": - return cls(seq.seq_id, seq.prompt, seq.prompt_token_ids, - seq.output_text, seq.get_output_token_ids(), - seq.is_finished(), - SequenceStatus.get_finished_reason(seq.get_status())) + return cls( + seq.seq_id, + seq.prompt, + seq.prompt_token_ids, + seq.output_text, + seq.get_output_token_ids(), + seq.is_finished(), + SequenceStatus.get_finished_reason(seq.get_status()), + ) diff --git a/sarathi/core/datatypes/sampling_params.py b/sarathi/core/datatypes/sampling_params.py index 6f80406..c6894e4 100644 --- a/sarathi/core/datatypes/sampling_params.py +++ b/sarathi/core/datatypes/sampling_params.py @@ -1,4 +1,5 @@ """Sampling parameters for text generation.""" + from enum import IntEnum from functools import cached_property from typing import List, Union @@ -56,15 +57,16 @@ def __init__( def _verify_args(self) -> None: if self.temperature < 0.0: raise ValueError( - f"temperature must be non-negative, got {self.temperature}.") + f"temperature must be non-negative, got {self.temperature}." + ) if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") if self.top_k < -1 or self.top_k == 0: - raise ValueError(f"top_k must be -1 (disable), or at least 1, " - f"got {self.top_k}.") - if self.max_tokens < 1: raise ValueError( - f"max_tokens must be at least 1, got {self.max_tokens}.") + f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." + ) + if self.max_tokens < 1: + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") def _verify_greedy_sampling(self) -> None: if self.top_p < 1.0 - _SAMPLING_EPS: @@ -79,9 +81,11 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM def __repr__(self) -> str: - return (f"SamplingParams(temperature={self.temperature}, " - f"top_p={self.top_p}, " - f"top_k={self.top_k}, " - f"stop={self.stop}, " - f"ignore_eos={self.ignore_eos}, " - f"max_tokens={self.max_tokens})") + return ( + f"SamplingParams(temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"stop={self.stop}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens})" + ) diff --git a/sarathi/core/datatypes/scheduler_output.py b/sarathi/core/datatypes/scheduler_output.py index ff58c4c..f6838dc 100644 --- a/sarathi/core/datatypes/scheduler_output.py +++ b/sarathi/core/datatypes/scheduler_output.py @@ -17,29 +17,30 @@ def __init__( self.preempted_seq_ids = preempted_seq_ids self.scheduled_seq_metadata_list = scheduled_seq_metadata_list self.prompt_chunk_lens = [ - metadata.num_prompt_tokens - for metadata in scheduled_seq_metadata_list + metadata.num_prompt_tokens for metadata in scheduled_seq_metadata_list ] self.num_batched_prompt_tokens = sum(self.prompt_chunk_lens) self.num_batched_output_tokens = sum( - metadata.num_output_tokens - for metadata in scheduled_seq_metadata_list) + metadata.num_output_tokens for metadata in scheduled_seq_metadata_list + ) self.num_batched_tokens = sum( - metadata.num_tokens for metadata in scheduled_seq_metadata_list) + metadata.num_tokens for metadata in scheduled_seq_metadata_list + ) def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. 
return not self.scheduled_seq_metadata_list def has_no_output(self) -> bool: - return (not self.scheduled_seq_metadata_list - and not self.ignored_seq_ids and not self.preempted_seq_ids) + return ( + not self.scheduled_seq_metadata_list + and not self.ignored_seq_ids + and not self.preempted_seq_ids + ) @property def seq_ids(self) -> List[str]: - return [ - metadata.seq_id for metadata in self.scheduled_seq_metadata_list - ] + return [metadata.seq_id for metadata in self.scheduled_seq_metadata_list] def __repr__(self) -> str: return ( @@ -47,4 +48,5 @@ def __repr__(self) -> str: f"new_seqs={self.new_seqs}, " f"ignored_seq_ids={self.ignored_seq_ids}, " f"preempted_seq_ids={self.preempted_seq_ids}, " - f"scheduled_seq_metadata_list={self.scheduled_seq_metadata_list})") + f"scheduled_seq_metadata_list={self.scheduled_seq_metadata_list})" + ) diff --git a/sarathi/core/datatypes/sequence.py b/sarathi/core/datatypes/sequence.py index 4e9b0d7..c232d5c 100644 --- a/sarathi/core/datatypes/sequence.py +++ b/sarathi/core/datatypes/sequence.py @@ -1,10 +1,11 @@ """Sequence and its related classes.""" + from typing import List, Optional from sarathi.core.datatypes.block import LogicalTokenBlock from sarathi.core.datatypes.sampling_params import SamplingParams -from sarathi.core.datatypes.sequence_status import SequenceStatus from sarathi.core.datatypes.sequence_state import SequenceState +from sarathi.core.datatypes.sequence_status import SequenceStatus class Sequence: @@ -79,8 +80,7 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: last_block = self.logical_token_blocks[-1] num_empty_slots = last_block.get_num_empty_slots() - last_block.append_tokens(token_ids[cursor:cursor + - num_empty_slots]) + last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots]) cursor += num_empty_slots def update_prompt_tokens_processed(self, num_tokens: int) -> None: @@ -136,8 +136,9 @@ def get_next_prompt_chunk_token_ids(self, chunk_size: int) -> List[int]: return self.prompt_token_ids[start:end] def get_next_prompt_chunk_len(self, chunk_size: int) -> int: - return min(chunk_size, - len(self.prompt_token_ids) - self.prompt_tokens_processed) + return min( + chunk_size, len(self.prompt_token_ids) - self.prompt_tokens_processed + ) def is_finished(self) -> bool: return SequenceStatus.is_finished(self.get_status()) @@ -167,7 +168,7 @@ def check_stop(self) -> None: if self.output_text.endswith(stop_str): # Truncate the output text so that the stop string is # not included in the output. - self.output_text = self.output_text[:-len(stop_str)] + self.output_text = self.output_text[: -len(stop_str)] self.set_status(SequenceStatus.FINISHED_STOPPED) return @@ -177,15 +178,18 @@ def check_stop(self) -> None: return # Check if the sequence has generated the EOS token. 
- if ((not self.sampling_params.ignore_eos) - and self.get_last_token_id() == self.eos_token_id): + if ( + not self.sampling_params.ignore_eos + ) and self.get_last_token_id() == self.eos_token_id: self.set_status(SequenceStatus.FINISHED_STOPPED) return def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.get_status().name}, " - f"num_blocks={len(self.logical_token_blocks)})") + return ( + f"Sequence(seq_id={self.seq_id}, " + f"status={self.get_status().name}, " + f"num_blocks={len(self.logical_token_blocks)})" + ) class SequenceScheduleMetadata: @@ -239,8 +243,10 @@ def from_sequence( return cls(seq_id=seq.seq_id, prompt_chunk_len=prompt_chunk_len) def __str__(self) -> str: - return (f"SequenceScheduleMetadata(seq_id={self.seq_id}, " - f"prompt_chunk_len={self.prompt_chunk_len})") + return ( + f"SequenceScheduleMetadata(seq_id={self.seq_id}, " + f"prompt_chunk_len={self.prompt_chunk_len})" + ) def __repr__(self) -> str: return self.__str__() @@ -283,8 +289,10 @@ def num_tokens(self) -> int: return max(self.prompt_chunk_len, 1) def __str__(self) -> str: - return (f"SequenceMetadata(seq_id={self.seq.seq_id}, " - f"prompt_chunk_len={self.prompt_chunk_len})") + return ( + f"SequenceMetadata(seq_id={self.seq.seq_id}, " + f"prompt_chunk_len={self.prompt_chunk_len})" + ) def __repr__(self) -> str: return self.__str__() @@ -307,14 +315,15 @@ def __init__( self.output_token = output_token def __repr__(self) -> str: - return (f"SamplerOutput(seq_id={self.seq_id}, " - f"output_token={self.output_token}))") + return ( + f"SamplerOutput(seq_id={self.seq_id}, " + f"output_token={self.output_token}))" + ) def __eq__(self, other: object) -> bool: if not isinstance(other, SamplerOutput): raise NotImplementedError() - return (self.seq_id == other.seq_id - and self.output_token == other.output_token) + return self.seq_id == other.seq_id and self.output_token == other.output_token SamplerOutputs = List[SamplerOutput] diff --git a/sarathi/core/datatypes/sequence_state.py b/sarathi/core/datatypes/sequence_state.py index 4fc99de..3c61d2c 100644 --- a/sarathi/core/datatypes/sequence_state.py +++ b/sarathi/core/datatypes/sequence_state.py @@ -74,12 +74,17 @@ def prompt_processing_completed_at(self) -> Optional[float]: @property def e2e_time(self) -> Optional[float]: - return self._completed_at - self._arrived_at if self._completed_at is not None else None + return ( + self._completed_at - self._arrived_at + if self._completed_at is not None + else None + ) @property def e2e_time_piecewise_normalized(self) -> float: - return self.scheduling_delay + (self.execution_plus_preemption_time / - self._num_output_tokens) + return self.scheduling_delay + ( + self.execution_plus_preemption_time / self._num_output_tokens + ) @property def e2e_time_normalized(self) -> float: @@ -87,41 +92,68 @@ def e2e_time_normalized(self) -> float: @property def e2e_prefill_time(self) -> Optional[float]: - return self._prompt_processing_completed_at - self._arrived_at if self._prompt_processing_completed_at is not None else None + return ( + self._prompt_processing_completed_at - self._arrived_at + if self._prompt_processing_completed_at is not None + else None + ) @property def e2e_prefill_time_normalized(self) -> Optional[float]: - return (self.e2e_prefill_time / self._num_prompt_tokens - ) if self._prompt_processing_completed_at is not None else None + return ( + (self.e2e_prefill_time / self._num_prompt_tokens) + if self._prompt_processing_completed_at is not None + else None + ) @property def 
e2e_prefill_time_piecewise_normalized(self) -> Optional[float]: - return self.scheduling_delay + ( - self.prefill_execution_plus_preemption_time / - self._num_prompt_tokens - ) if self._prompt_processing_completed_at else None + return ( + self.scheduling_delay + + (self.prefill_execution_plus_preemption_time / self._num_prompt_tokens) + if self._prompt_processing_completed_at + else None + ) @property def prefill_execution_plus_preemption_time(self) -> float: - return self._prompt_processing_completed_at - self._scheduled_at if self._prompt_processing_completed_at is not None else None + return ( + self._prompt_processing_completed_at - self._scheduled_at + if self._prompt_processing_completed_at is not None + else None + ) @property def decode_execution_plus_preemption_time(self) -> float: - return self._completed_at - self._prompt_processing_completed_at if self._completed_at is not None else None + return ( + self._completed_at - self._prompt_processing_completed_at + if self._completed_at is not None + else None + ) @property - def prefill_execution_plus_preemption_time_normalized( - self) -> Optional[float]: - return self.prefill_execution_plus_preemption_time / self._num_prompt_tokens if self.prefill_execution_plus_preemption_time else None + def prefill_execution_plus_preemption_time_normalized(self) -> Optional[float]: + return ( + self.prefill_execution_plus_preemption_time / self._num_prompt_tokens + if self.prefill_execution_plus_preemption_time + else None + ) @property - def decode_execution_plus_preemption_time_normalized( - self) -> Optional[float]: - return self.decode_execution_plus_preemption_time / self._num_output_tokens if self.decode_execution_plus_preemption_time else None + def decode_execution_plus_preemption_time_normalized(self) -> Optional[float]: + return ( + self.decode_execution_plus_preemption_time / self._num_output_tokens + if self.decode_execution_plus_preemption_time + else None + ) @property def scheduling_delay(self) -> Optional[float]: - return self._scheduled_at - self._arrived_at if self._scheduled_at is not None else None + return ( + self._scheduled_at - self._arrived_at + if self._scheduled_at is not None + else None + ) @property def execution_time(self) -> float: @@ -160,7 +192,8 @@ def is_ignore_finished(self) -> bool: return self._is_ignore_finished def _handle_transitions_from_waiting_status( - self, current_time: float, status: SequenceStatus) -> None: + self, current_time: float, status: SequenceStatus + ) -> None: if status == SequenceStatus.RUNNING: # request is starting execution now if self._scheduled_at is None: @@ -186,7 +219,8 @@ def _handle_transitions_from_waiting_status( ) def _handle_transitions_from_running_status( - self, current_time: float, status: SequenceStatus) -> None: + self, current_time: float, status: SequenceStatus + ) -> None: self._execution_time += current_time - self._last_execution_start_at if status == SequenceStatus.PAUSED: @@ -200,14 +234,15 @@ def _handle_transitions_from_running_status( f"Invalid state transition from {self._status} to {status} for request {self._id}." 
) - def _handle_transitions_from_paused_status(self, current_time: float, - status: SequenceStatus) -> None: + def _handle_transitions_from_paused_status( + self, current_time: float, status: SequenceStatus + ) -> None: self._preempted_time += current_time - self._last_pause_at - if status in [ - SequenceStatus.FINISHED_STOPPED, - SequenceStatus.FINISHED_LENGTH_CAPPED - ]: + if ( + status == SequenceStatus.FINISHED_STOPPED + or status == SequenceStatus.FINISHED_LENGTH_CAPPED + ): self._is_completed = True self._completed_at = current_time elif status == SequenceStatus.RUNNING: @@ -247,6 +282,8 @@ def on_token_generated(self) -> None: if not self._last_token_generated_at: self._last_token_generation_time = 0 else: - self._last_token_generation_time = current_time - self._last_token_generated_at + self._last_token_generation_time = ( + current_time - self._last_token_generated_at + ) self._last_token_generated_at = current_time diff --git a/sarathi/core/policy.py b/sarathi/core/policy.py index 092fb76..fd9f49f 100644 --- a/sarathi/core/policy.py +++ b/sarathi/core/policy.py @@ -37,7 +37,7 @@ def get_priority( class PolicyFactory: _POLICY_REGISTRY = { - 'fcfs': FCFS, + "fcfs": FCFS, } @classmethod diff --git a/sarathi/core/scheduler/base_scheduler.py b/sarathi/core/scheduler/base_scheduler.py index 9c069fe..e6fafea 100644 --- a/sarathi/core/scheduler/base_scheduler.py +++ b/sarathi/core/scheduler/base_scheduler.py @@ -2,14 +2,15 @@ from abc import ABC, abstractmethod from typing import List, Tuple -from sarathi.config import CacheConfig, BaseSchedulerConfig +from sarathi.config import BaseSchedulerConfig, CacheConfig +from sarathi.core.block_space_manager.block_space_manager_registry import ( + BlockSpaceManagerRegistry, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceStatus from sarathi.core.policy import PolicyFactory from sarathi.logger import init_logger -from sarathi.core.datatypes.sequence import Sequence -from sarathi.core.datatypes.sequence import SequenceStatus -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.metrics.metrics_store import MetricsStore -from sarathi.core.block_space_manager.block_space_manager_registry import BlockSpaceManagerRegistry logger = init_logger(__name__) @@ -123,7 +124,8 @@ def _check_request_prompt_length(self, seq: Sequence) -> bool: if seq.get_len() > self.scheduler_config.max_model_len: logger.warning( f"Input prompt ({seq.get_len()} tokens) is too long" - f" and exceeds limit of {seq.sampling_params.max_tokens}") + f" and exceeds limit of {seq.sampling_params.max_tokens}" + ) seq.set_status(SequenceStatus.FINISHED_IGNORED) self.waiting.pop(0) return False diff --git a/sarathi/core/scheduler/faster_transformer_scheduler.py b/sarathi/core/scheduler/faster_transformer_scheduler.py index 5b1ecdd..b7b0774 100644 --- a/sarathi/core/scheduler/faster_transformer_scheduler.py +++ b/sarathi/core/scheduler/faster_transformer_scheduler.py @@ -2,12 +2,14 @@ from typing import List from sarathi.config import CacheConfig, FasterTransformerSchedulerConfig -from sarathi.logger import init_logger +from sarathi.core.block_space_manager.faster_transformer_block_space_manager import ( + FasterTransformerBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import SequenceScheduleMetadata from sarathi.core.datatypes.sequence_status import SequenceStatus from 
sarathi.core.scheduler.base_scheduler import BaseScheduler -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.block_space_manager.faster_transformer_block_space_manager import FasterTransformerBlockSpaceManager +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -37,7 +39,8 @@ def _schedule(self) -> SchedulerOutputs: assert seq.prompt_processing_finished scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) if scheduled_seq_metadata_list: return SchedulerOutputs( @@ -74,7 +77,8 @@ def _schedule(self) -> SchedulerOutputs: self._allocate(seq) self.running.append(seq) scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) scheduler_outputs = SchedulerOutputs( id=self._iteration_id, diff --git a/sarathi/core/scheduler/orca_scheduler.py b/sarathi/core/scheduler/orca_scheduler.py index abde297..5bc83b9 100644 --- a/sarathi/core/scheduler/orca_scheduler.py +++ b/sarathi/core/scheduler/orca_scheduler.py @@ -2,11 +2,13 @@ from typing import List from sarathi.config import CacheConfig, OrcaSchedulerConfig -from sarathi.logger import init_logger +from sarathi.core.block_space_manager.orca_block_space_manager import ( + OrcaBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import SequenceScheduleMetadata from sarathi.core.scheduler.base_scheduler import BaseScheduler -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.block_space_manager.orca_block_space_manager import OrcaBlockSpaceManager +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -38,7 +40,8 @@ def _schedule(self) -> SchedulerOutputs: assert seq.prompt_processing_finished scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) # Optimization: We do not sort the waiting queue since the preempted # sequences are added to the front and the new sequences @@ -65,7 +68,8 @@ def _schedule(self) -> SchedulerOutputs: self._allocate(seq) self.running.append(seq) scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) return SchedulerOutputs( id=self._iteration_id, diff --git a/sarathi/core/scheduler/sarathi_scheduler.py b/sarathi/core/scheduler/sarathi_scheduler.py index c1d886b..7d58008 100644 --- a/sarathi/core/scheduler/sarathi_scheduler.py +++ b/sarathi/core/scheduler/sarathi_scheduler.py @@ -4,11 +4,13 @@ import numpy as np from sarathi.config import CacheConfig, SarathiSchedulerConfig -from sarathi.logger import init_logger +from sarathi.core.block_space_manager.sarathi_block_space_manager import ( + SarathiBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata from sarathi.core.scheduler.base_scheduler import BaseScheduler -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.block_space_manager.sarathi_block_space_manager import SarathiBlockSpaceManager +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -24,7 +26,9 @@ def __init__( self.prompt_limit = self.scheduler_config.max_model_len self.chunk_size = self.scheduler_config.chunk_size - 
self.enable_dynamic_chunking_schedule = self.scheduler_config.enable_dynamic_chunking_schedule + self.enable_dynamic_chunking_schedule = ( + self.scheduler_config.enable_dynamic_chunking_schedule + ) # next four params apply only when using dynamic schedule self.low_chunk_size = self.scheduler_config.low_chunk_size self.high_chunk_size = self.scheduler_config.high_chunk_size @@ -38,19 +42,22 @@ def __init__( assert self.high_chunk_size % 32 == 0 self._chunk_sizes = self._compute_chunk_size_schedule() self._tokens_per_stage = int( - np.ceil(self.chunk_schedule_max_tokens / - self.chunk_schedule_stages)) + np.ceil(self.chunk_schedule_max_tokens / self.chunk_schedule_stages) + ) def _compute_chunk_size_schedule(self): # create num_steps equally spaced chunk sizes between low_chunk_size and high_chunk_size - chunk_sizes = np.linspace(self.low_chunk_size, - self.high_chunk_size, - self.chunk_schedule_stages, - dtype=np.int32)[::-1] + chunk_sizes = np.linspace( + self.low_chunk_size, + self.high_chunk_size, + self.chunk_schedule_stages, + dtype=np.int32, + )[::-1] # align each chunk size to the nearest multiple of 32 or self.low_chunk_size round_of_chunk_sizes = min(32, self.low_chunk_size) - chunk_sizes = np.round( - chunk_sizes / round_of_chunk_sizes) * round_of_chunk_sizes + chunk_sizes = ( + np.round(chunk_sizes / round_of_chunk_sizes) * round_of_chunk_sizes + ) chunk_sizes = chunk_sizes.astype(np.int64).tolist() return chunk_sizes @@ -58,14 +65,15 @@ def _compute_chunk_size_schedule(self): def get_block_space_manager_class(self): return SarathiBlockSpaceManager - def _get_seq_next_num_prefill_tokens(self, seq: Sequence, - num_batched_tokens: int) -> int: + def _get_seq_next_num_prefill_tokens( + self, seq: Sequence, num_batched_tokens: int + ) -> int: assert not seq.is_finished() if self.enable_dynamic_chunking_schedule: request_stage_idx = int( - np.ceil(seq.get_num_prompt_tokens_processed() // - self._tokens_per_stage)) + np.ceil(seq.get_num_prompt_tokens_processed() // self._tokens_per_stage) + ) assert request_stage_idx < len(self._chunk_sizes) chunk_size = self._chunk_sizes[request_stage_idx] else: @@ -73,7 +81,8 @@ def _get_seq_next_num_prefill_tokens(self, seq: Sequence, next_num_tokens = min( seq.get_prompt_len() - seq.get_num_prompt_tokens_processed(), - chunk_size - num_batched_tokens) + chunk_size - num_batched_tokens, + ) return next_num_tokens @@ -138,7 +147,8 @@ def _schedule(self) -> SchedulerOutputs: running.append(seq) num_batched_tokens += 1 scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) # now add the requests with prefill incomplete # the memory for all these prefills has already been allocated @@ -147,7 +157,8 @@ def _schedule(self) -> SchedulerOutputs: assert not seq.prompt_processing_finished next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( - seq, num_batched_tokens) + seq, num_batched_tokens + ) # as long as the request could fit in the batch previously # it should be able to fit in the batch now @@ -161,7 +172,9 @@ def _schedule(self) -> SchedulerOutputs: num_batched_tokens += next_num_prefill_tokens scheduled_seq_metadata_list.append( SequenceScheduleMetadata.from_sequence( - seq, prompt_chunk_len=next_num_prefill_tokens)) + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) running.append(seq) ###################################################################### @@ -196,7 +209,8 @@ def _schedule(self) -> SchedulerOutputs: # check if we can fit the prefill in 
the batch next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( - seq, num_batched_tokens) + seq, num_batched_tokens + ) if next_num_prefill_tokens == 0: break @@ -206,7 +220,9 @@ def _schedule(self) -> SchedulerOutputs: num_batched_tokens += next_num_prefill_tokens scheduled_seq_metadata_list.append( SequenceScheduleMetadata.from_sequence( - seq, prompt_chunk_len=next_num_prefill_tokens)) + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) running.append(seq) # make sure that prefills are at the start of the batch, so that we don't violate assumptions diff --git a/sarathi/core/scheduler/scheduler_registry.py b/sarathi/core/scheduler/scheduler_registry.py index 7d859c9..f5b2c7e 100644 --- a/sarathi/core/scheduler/scheduler_registry.py +++ b/sarathi/core/scheduler/scheduler_registry.py @@ -1,9 +1,11 @@ -from sarathi.core.scheduler.vllm_scheduler import VLLMScheduler +from sarathi.config import SchedulerType +from sarathi.core.scheduler.faster_transformer_scheduler import ( + FasterTransformerScheduler, +) from sarathi.core.scheduler.orca_scheduler import OrcaScheduler -from sarathi.core.scheduler.faster_transformer_scheduler import FasterTransformerScheduler from sarathi.core.scheduler.sarathi_scheduler import SarathiScheduler from sarathi.core.scheduler.simple_chunking_scheduler import SimpleChunkingScheduler -from sarathi.config import SchedulerType +from sarathi.core.scheduler.vllm_scheduler import VLLMScheduler from sarathi.utils.base_registry import BaseRegistry @@ -16,8 +18,6 @@ def get_key_from_str(cls, key_str: str) -> SchedulerType: SchedulerRegistry.register(SchedulerType.VLLM, VLLMScheduler) SchedulerRegistry.register(SchedulerType.ORCA, OrcaScheduler) -SchedulerRegistry.register(SchedulerType.FASTER_TRANSFORMER, - FasterTransformerScheduler) +SchedulerRegistry.register(SchedulerType.FASTER_TRANSFORMER, FasterTransformerScheduler) SchedulerRegistry.register(SchedulerType.SARATHI, SarathiScheduler) -SchedulerRegistry.register(SchedulerType.SIMPLE_CHUNKING, - SimpleChunkingScheduler) +SchedulerRegistry.register(SchedulerType.SIMPLE_CHUNKING, SimpleChunkingScheduler) diff --git a/sarathi/core/scheduler/simple_chunking_scheduler.py b/sarathi/core/scheduler/simple_chunking_scheduler.py index eedb77d..ac06c7b 100644 --- a/sarathi/core/scheduler/simple_chunking_scheduler.py +++ b/sarathi/core/scheduler/simple_chunking_scheduler.py @@ -3,12 +3,14 @@ from typing import List from sarathi.config import CacheConfig, SimpleChunkingSchedulerConfig -from sarathi.logger import init_logger +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata from sarathi.core.datatypes.sequence_status import SequenceStatus from sarathi.core.scheduler.base_scheduler import BaseScheduler -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.block_space_manager.vllm_block_space_manager import VLLMBlockSpaceManager +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -34,13 +36,15 @@ def __init__( def get_block_space_manager_class(self): return VLLMBlockSpaceManager - def _get_seq_next_num_prefill_tokens(self, seq: Sequence, - num_batched_tokens: int) -> int: + def _get_seq_next_num_prefill_tokens( + self, seq: Sequence, num_batched_tokens: int + ) -> int: assert not seq.is_finished() next_num_tokens = min( seq.get_prompt_len() - 
seq.get_num_prompt_tokens_processed(), - self.chunk_size - num_batched_tokens) + self.chunk_size - num_batched_tokens, + ) return next_num_tokens @@ -75,7 +79,8 @@ def _schedule(self) -> SchedulerOutputs: continue next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( - seq, num_batched_tokens) + seq, num_batched_tokens + ) if next_num_prefill_tokens == 0: # not enough token space to allocate the sequence @@ -86,7 +91,9 @@ def _schedule(self) -> SchedulerOutputs: running.append(seq) scheduled_seq_metadata_list.append( SequenceScheduleMetadata.from_sequence( - seq, prompt_chunk_len=next_num_prefill_tokens)) + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) if running: assert not self.running @@ -121,7 +128,8 @@ def _schedule(self) -> SchedulerOutputs: break next_num_prefill_tokens = self._get_seq_next_num_prefill_tokens( - seq, num_batched_tokens) + seq, num_batched_tokens + ) if next_num_prefill_tokens == 0: # not enough space to allocate the sequence @@ -133,7 +141,9 @@ def _schedule(self) -> SchedulerOutputs: num_batched_tokens += next_num_prefill_tokens scheduled_seq_metadata_list.append( SequenceScheduleMetadata.from_sequence( - seq, prompt_chunk_len=next_num_prefill_tokens)) + seq, prompt_chunk_len=next_num_prefill_tokens + ) + ) if scheduled_seq_metadata_list or ignored_seq_ids: self.whose_turn = Turn.DECODE @@ -175,7 +185,8 @@ def _schedule(self) -> SchedulerOutputs: self._append_slot(seq) running.append(seq) scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) self.running = running self.whose_turn = Turn.PREFILL diff --git a/sarathi/core/scheduler/vllm_scheduler.py b/sarathi/core/scheduler/vllm_scheduler.py index 6d5e6a1..41a276c 100644 --- a/sarathi/core/scheduler/vllm_scheduler.py +++ b/sarathi/core/scheduler/vllm_scheduler.py @@ -2,11 +2,13 @@ from typing import List from sarathi.config import CacheConfig, VLLMSchedulerConfig -from sarathi.logger import init_logger +from sarathi.core.block_space_manager.vllm_block_space_manager import ( + VLLMBlockSpaceManager, +) +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata from sarathi.core.scheduler.base_scheduler import BaseScheduler -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.block_space_manager.vllm_block_space_manager import VLLMBlockSpaceManager +from sarathi.logger import init_logger logger = init_logger(__name__) @@ -20,8 +22,10 @@ def __init__( ) -> None: super().__init__(scheduler_config, cache_config) - self.prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + self.prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) def get_block_space_manager_class(self): return VLLMBlockSpaceManager @@ -57,8 +61,10 @@ def _schedule(self) -> SchedulerOutputs: break # If the number of batched tokens exceeds the limit, stop. 
- if (num_batched_tokens + num_prompt_tokens - > self.scheduler_config.max_num_batched_tokens): + if ( + num_batched_tokens + num_prompt_tokens + > self.scheduler_config.max_num_batched_tokens + ): break if len(self.running) + 1 > self.scheduler_config.max_num_seqs: @@ -68,7 +74,8 @@ def _schedule(self) -> SchedulerOutputs: self._allocate(seq) num_batched_tokens += num_prompt_tokens scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) self.running.append(seq) if scheduled_seq_metadata_list or ignored_seq_ids: @@ -115,7 +122,8 @@ def _schedule(self) -> SchedulerOutputs: self._append_slot(seq) running.append(seq) scheduled_seq_metadata_list.append( - SequenceScheduleMetadata.from_sequence(seq)) + SequenceScheduleMetadata.from_sequence(seq) + ) self.running = running diff --git a/sarathi/core/sequence_manager/base_sequence_manager.py b/sarathi/core/sequence_manager/base_sequence_manager.py index ea99493..5b46d02 100644 --- a/sarathi/core/sequence_manager/base_sequence_manager.py +++ b/sarathi/core/sequence_manager/base_sequence_manager.py @@ -1,15 +1,15 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from typing import List, Optional, Tuple +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import ( - Sequence, - SequenceScheduleMetadata, - SequenceMetadata, SamplerOutput, SamplerOutputs, + Sequence, + SequenceMetadata, + SequenceScheduleMetadata, ) -from sarathi.core.datatypes.request_output import RequestOutput -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence_status import SequenceStatus from sarathi.utils.threading_utils import synchronized @@ -46,8 +46,7 @@ def _resume_seq(self, seq_id: int) -> None: assert seq.is_waiting() or seq.is_paused() seq.set_status(SequenceStatus.RUNNING) - def _on_seq_scheduled( - self, seq_sched_metadata: SequenceScheduleMetadata) -> None: + def _on_seq_scheduled(self, seq_sched_metadata: SequenceScheduleMetadata) -> None: assert seq_sched_metadata.seq_id in self.seq_map self._resume_seq(seq_sched_metadata.seq_id) @@ -75,8 +74,12 @@ def on_schedule( self._on_seq_scheduled(seq_sched_metadata) seq = self.seq_map[seq_sched_metadata.seq_id] seq_metadata_list.append( - SequenceMetadata(seq, self._get_block_table(seq), - seq_sched_metadata.num_prompt_tokens)) + SequenceMetadata( + seq, + self._get_block_table(seq), + seq_sched_metadata.num_prompt_tokens, + ) + ) return ignored_seqs, seq_metadata_list @@ -84,8 +87,9 @@ def on_schedule( def _on_append_token(self, seq: Sequence) -> None: pass - def _process_seq_output(self, seq_id: int, sample: SamplerOutput, - prompt_chunk_len: int) -> None: + def _process_seq_output( + self, seq_id: int, sample: SamplerOutput, prompt_chunk_len: int + ) -> None: assert seq_id in self.seq_map seq = self.seq_map[seq_id] # at this point, the seq should be in paused state @@ -110,8 +114,8 @@ def on_step_completed( sampler_outputs: Optional[SamplerOutputs], ) -> None: for scheduled_seq_metadata, sampler_output in zip( - scheduler_outputs.scheduled_seq_metadata_list, - sampler_outputs): + scheduler_outputs.scheduled_seq_metadata_list, sampler_outputs + ): seq = self.seq_map[scheduled_seq_metadata.seq_id] if seq.is_waiting(): # seq is preempted diff --git a/sarathi/core/sequence_manager/engine_sequence_manager.py 
b/sarathi/core/sequence_manager/engine_sequence_manager.py index 58bef3c..99ebb08 100644 --- a/sarathi/core/sequence_manager/engine_sequence_manager.py +++ b/sarathi/core/sequence_manager/engine_sequence_manager.py @@ -1,29 +1,30 @@ -from typing import Union, List +from typing import List, Union + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from sarathi.core.datatypes.sequence import Sequence from sarathi.core.sequence_manager.base_sequence_manager import BaseSequenceManager -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from sarathi.transformers_utils.tokenizer import detokenize_incrementally class EngineSequenceManager(BaseSequenceManager): - def __init__(self, tokenizer: Union[PreTrainedTokenizer, - PreTrainedTokenizerFast]): + def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): super().__init__() self.tokenizer = tokenizer def _decode_seq(self, seq: Sequence) -> None: """Decodes the new token for a sequence.""" - (new_tokens, new_output_text, prefix_offset, - read_offset) = detokenize_incrementally( - self.tokenizer, - all_input_ids=seq.get_token_ids(), - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=True, - ) + (new_tokens, new_output_text, prefix_offset, read_offset) = ( + detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=True, + ) + ) if seq.tokens is None: seq.tokens = new_tokens else: diff --git a/sarathi/core/sequence_manager/worker_sequence_manager.py b/sarathi/core/sequence_manager/worker_sequence_manager.py index 4dc3ba4..6b1faa6 100644 --- a/sarathi/core/sequence_manager/worker_sequence_manager.py +++ b/sarathi/core/sequence_manager/worker_sequence_manager.py @@ -1,7 +1,10 @@ from typing import List + +from sarathi.config import BaseSchedulerConfig, CacheConfig +from sarathi.core.block_space_manager.block_space_manager_registry import ( + BlockSpaceManagerRegistry, +) from sarathi.core.datatypes.sequence import Sequence, SequenceScheduleMetadata -from sarathi.core.block_space_manager.block_space_manager_registry import BlockSpaceManagerRegistry -from sarathi.config import CacheConfig, BaseSchedulerConfig from sarathi.core.sequence_manager.base_sequence_manager import BaseSequenceManager @@ -36,8 +39,7 @@ def _preempt_seq(self, seq_id: int) -> None: seq = self.seq_map[seq_id] self.block_manager.free(seq) - def _on_seq_scheduled( - self, seq_sched_metadata: SequenceScheduleMetadata) -> None: + def _on_seq_scheduled(self, seq_sched_metadata: SequenceScheduleMetadata) -> None: super()._on_seq_scheduled(seq_sched_metadata) seq = self.seq_map[seq_sched_metadata.seq_id] diff --git a/sarathi/engine/arg_utils.py b/sarathi/engine/arg_utils.py index 6e34fb9..e17d827 100644 --- a/sarathi/engine/arg_utils.py +++ b/sarathi/engine/arg_utils.py @@ -1,31 +1,40 @@ import dataclasses -from dataclasses import asdict, dataclass import os -from typing import Optional, Tuple, List +from dataclasses import asdict, dataclass +from typing import List, Optional, Tuple import yaml -from sarathi.config import (CacheConfig, MetricsConfig, ModelConfig, - ParallelConfig, BaseSchedulerConfig, - VLLMSchedulerConfig, OrcaSchedulerConfig, - FasterTransformerSchedulerConfig, - SimpleChunkingSchedulerConfig, - SarathiSchedulerConfig, SchedulerType) +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + 
FasterTransformerSchedulerConfig, + MetricsConfig, + ModelConfig, + OrcaSchedulerConfig, + ParallelConfig, + SarathiSchedulerConfig, + SchedulerType, + SimpleChunkingSchedulerConfig, + VLLMSchedulerConfig, +) @dataclass class EngineArgs: """Arguments for Sarathi engine.""" + model: str replica_id: int = 0 replica_resource_mapping: List[Tuple[str, int]] = dataclasses.field( - default_factory=list) + default_factory=list + ) tokenizer: Optional[str] = None - tokenizer_mode: str = 'auto' + tokenizer_mode: str = "auto" trust_remote_code: bool = False download_dir: Optional[str] = None - load_format: str = 'auto' - dtype: str = 'auto' + load_format: str = "auto" + dtype: str = "auto" seed: int = 0 pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 @@ -33,7 +42,7 @@ class EngineArgs: gpu_memory_utilization: float = 0.85 revision: Optional[str] = None # scheduler parameters - scheduler_type: str = 'sarathi' + scheduler_type: str = "sarathi" max_model_len: Optional[int] = None max_num_seqs: int = 256 # vllm scheduler parameters @@ -47,7 +56,7 @@ class EngineArgs: chunk_schedule_stages: Optional[int] = None # Metrics store parameters write_metrics: bool = True - output_dir: str = '.' + output_dir: str = "." wandb_project: Optional[str] = None wandb_sweep_id: Optional[str] = None wandb_run_id: Optional[str] = None @@ -58,21 +67,19 @@ class EngineArgs: enable_chrome_trace: bool = False enable_request_outputs: bool = False keep_individual_batch_metrics: bool = False - attention_backend: str = 'flash_attention' + attention_backend: str = "flash_attention" def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model if self.write_metrics: os.makedirs(self.output_dir, exist_ok=True) - with open(f'{self.output_dir}/config.yml', 'w') as f: - yaml.dump(asdict(self), - f, - default_flow_style=False, - sort_keys=False) + with open(f"{self.output_dir}/config.yml", "w") as f: + yaml.dump(asdict(self), f, default_flow_style=False, sort_keys=False) - def _get_scheduler_config(self, model_config: ModelConfig, - num_pipeline_stages: int) -> BaseSchedulerConfig: + def _get_scheduler_config( + self, model_config: ModelConfig, num_pipeline_stages: int + ) -> BaseSchedulerConfig: if self.scheduler_type == SchedulerType.VLLM.name.lower(): scheduler_config = VLLMSchedulerConfig( self.max_num_seqs, @@ -86,8 +93,7 @@ def _get_scheduler_config(self, model_config: ModelConfig, model_config.get_max_model_len(), num_pipeline_stages, ) - elif self.scheduler_type == SchedulerType.FASTER_TRANSFORMER.name.lower( - ): + elif self.scheduler_type == SchedulerType.FASTER_TRANSFORMER.name.lower(): scheduler_config = FasterTransformerSchedulerConfig( self.max_num_seqs, model_config.get_max_model_len(), @@ -113,36 +119,40 @@ def _get_scheduler_config(self, model_config: ModelConfig, self.chunk_size, ) else: - raise ValueError( - f'Unsupported scheduler type: {self.scheduler_type}') + raise ValueError(f"Unsupported scheduler type: {self.scheduler_type}") return scheduler_config def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, BaseSchedulerConfig, - MetricsConfig]: - model_config = ModelConfig(model=self.model, - tokenizer=self.tokenizer, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code, - download_dir=self.download_dir, - load_format=self.load_format, - dtype=self.dtype, - seed=self.seed, - revision=self.revision, - max_model_len=self.max_model_len, - attention_backend=self.attention_backend) + ) -> Tuple[ + ModelConfig, CacheConfig, 
ParallelConfig, BaseSchedulerConfig, MetricsConfig + ]: + model_config = ModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + download_dir=self.download_dir, + load_format=self.load_format, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + max_model_len=self.max_model_len, + attention_backend=self.attention_backend, + ) cache_config = CacheConfig( block_size=self.block_size, - gpu_memory_utilization=self.gpu_memory_utilization) + gpu_memory_utilization=self.gpu_memory_utilization, + ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, - replica_resource_mapping=self.replica_resource_mapping) + replica_resource_mapping=self.replica_resource_mapping, + ) scheduler_config = self._get_scheduler_config( - model_config=model_config, - num_pipeline_stages=self.pipeline_parallel_size) + model_config=model_config, num_pipeline_stages=self.pipeline_parallel_size + ) metrics_config = MetricsConfig( replica_id=self.replica_id, write_metrics=self.write_metrics, @@ -159,4 +169,10 @@ def create_engine_configs( keep_individual_batch_metrics=self.keep_individual_batch_metrics, model_num_layers=model_config.get_total_num_layers(), ) - return model_config, cache_config, parallel_config, scheduler_config, metrics_config + return ( + model_config, + cache_config, + parallel_config, + scheduler_config, + metrics_config, + ) diff --git a/sarathi/engine/base_llm_engine.py b/sarathi/engine/base_llm_engine.py index 1ebd83f..8d80b34 100644 --- a/sarathi/engine/base_llm_engine.py +++ b/sarathi/engine/base_llm_engine.py @@ -2,24 +2,28 @@ import math import time from functools import partial -from typing import Any, List, Optional, Tuple, Dict - -from sarathi.config import (CacheConfig, MetricsConfig, ModelConfig, - ParallelConfig, BaseSchedulerConfig) +from typing import Any, Dict, List, Optional, Tuple + +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + MetricsConfig, + ModelConfig, + ParallelConfig, +) +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.sampling_params import SamplingParams +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import SamplerOutputs, Sequence, SequenceMetadata from sarathi.core.scheduler.scheduler_registry import SchedulerRegistry +from sarathi.core.sequence_manager.engine_sequence_manager import EngineSequenceManager from sarathi.engine.ray_utils import RayWorker, initialize_cluster, ray from sarathi.logger import init_logger from sarathi.metrics.constants import CpuOperationMetrics from sarathi.metrics.cpu_timer import CpuTimer from sarathi.metrics.metrics_store import MetricsStore -from sarathi.core.datatypes.request_output import RequestOutput -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.datatypes.sampling_params import SamplingParams -from sarathi.core.datatypes.sequence import (SamplerOutputs, Sequence, - SequenceMetadata) from sarathi.transformers_utils.tokenizer import get_tokenizer -from sarathi.core.sequence_manager.engine_sequence_manager import EngineSequenceManager -from sarathi.utils import Counter, unset_cuda_visible_devices, get_ip, get_random_port +from sarathi.utils import Counter, get_ip, get_random_port, unset_cuda_visible_devices logger = init_logger(__name__) @@ -70,7 +74,8 @@ def __init__( 
f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"pipeline_parallel_size={parallel_config.pipeline_parallel_size}, " - f"seed={model_config.seed})") + f"seed={model_config.seed})" + ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config @@ -84,7 +89,8 @@ def __init__( model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision) + revision=model_config.revision, + ) self.seq_manager = EngineSequenceManager(self.tokenizer) self.seq_counter = Counter() @@ -108,12 +114,14 @@ def __init__( self.mark_initial_memory_profiling_done() # Create the scheduler. - self.scheduler = SchedulerRegistry.get(scheduler_config.type, - scheduler_config, cache_config) + self.scheduler = SchedulerRegistry.get( + scheduler_config.type, scheduler_config, cache_config + ) self._scheduler_timer = CpuTimer(CpuOperationMetrics.SCHEDULE) self._process_model_outputs_timer = CpuTimer( - CpuOperationMetrics.PROCESS_MODEL_OUTPUTS) + CpuOperationMetrics.PROCESS_MODEL_OUTPUTS + ) def _validate_parallel_config(self) -> None: assert self.parallel_config.pipeline_parallel_size == 1 @@ -121,7 +129,10 @@ def _validate_parallel_config(self) -> None: def _get_worker_impl(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from sarathi.worker.base_worker import BaseWorker # pylint: disable=import-outside-toplevel + from sarathi.worker.base_worker import ( + BaseWorker, # pylint: disable=import-outside-toplevel + ) + return BaseWorker def _init_workers_ray(self, **ray_remote_kwargs): @@ -151,7 +162,8 @@ def _init_workers_ray(self, **ray_remote_kwargs): ) else: worker_class = worker_class.options( - max_concurrency=_MAX_WORKER_CONCURRENCY, ) + max_concurrency=_MAX_WORKER_CONCURRENCY, + ) if rank == 0: if node_ip: @@ -193,7 +205,8 @@ def _init_workers_ray(self, **ray_remote_kwargs): local_rank, rank, distributed_init_method, - )) + ) + ) ray.get(promise) self._run_workers( @@ -223,22 +236,26 @@ def _init_cache(self) -> None: logger.info(f"# GPU blocks: {num_gpu_blocks}") if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_blocks_per_request = math.ceil(self.model_config.max_model_len / - self.cache_config.block_size) + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." + ) + max_blocks_per_request = math.ceil( + self.model_config.max_model_len / self.cache_config.block_size + ) if num_gpu_blocks < max_blocks_per_request: raise ValueError( f"Not enough available memory to schedule a request will maximum allowed length {self.model_config.max_model_len}. " f"Need {max_blocks_per_request}, available {num_gpu_blocks} gpu blocks. " - f"Try decreasing `max_batch_size`, `max_model_len`.") + f"Try decreasing `max_batch_size`, `max_model_len`." + ) self.cache_config.num_gpu_blocks = num_gpu_blocks # Initialize the cache. 
- self._run_workers("init_cache_engine", - cache_config=self.cache_config, - get_all_outputs=True) + self._run_workers( + "init_cache_engine", cache_config=self.cache_config, get_all_outputs=True + ) def _init_worker_map(self) -> None: model_parallel_ranks = self._run_workers( @@ -246,10 +263,7 @@ def _init_worker_map(self) -> None: get_all_outputs=True, ) - self.worker_map = { - mp_rank: i - for i, mp_rank in enumerate(model_parallel_ranks) - } + self.worker_map = {mp_rank: i for i, mp_rank in enumerate(model_parallel_ranks)} def _on_step_completed( self, @@ -275,7 +289,8 @@ def _on_step_completed( batch_end_time=end_time, ) all_request_outputs = self.seq_manager.generate_request_outputs( - ignored_seqs, seq_metadata_list) + ignored_seqs, seq_metadata_list + ) return all_request_outputs def add_request( @@ -312,8 +327,15 @@ def add_request( block_size = self.cache_config.block_size eos_token_id = self.tokenizer.eos_token_id seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, arrival_time, sampling_params) + seq = Sequence( + seq_id, + prompt, + prompt_token_ids, + block_size, + eos_token_id, + arrival_time, + sampling_params, + ) # Add the sequence to the scheduler. self.seq_manager.add_seq(seq) self._run_workers( @@ -350,16 +372,21 @@ def step(self) -> List[RequestOutput]: return [] ignored_seqs, seq_metadata_list = self.seq_manager.on_schedule( - scheduler_outputs) + scheduler_outputs + ) sampler_outputs = self._run_workers( "execute_model", scheduler_outputs=scheduler_outputs, ) - return self._on_step_completed(scheduler_outputs, ignored_seqs, - seq_metadata_list, sampler_outputs, - start_time) + return self._on_step_completed( + scheduler_outputs, + ignored_seqs, + seq_metadata_list, + sampler_outputs, + start_time, + ) def _run_workers( self, @@ -427,8 +454,7 @@ def pull_worker_metrics(self) -> None: def mark_initial_memory_profiling_done(self): self.metrics_store.mark_initial_memory_profiling_done() - self._run_workers("mark_initial_memory_profiling_done", - get_all_outputs=True) + self._run_workers("mark_initial_memory_profiling_done", get_all_outputs=True) def reset_metrics(self) -> None: self.scheduler.reset_state() diff --git a/sarathi/engine/pipeline_parallel_llm_engine.py b/sarathi/engine/pipeline_parallel_llm_engine.py index eace181..8784853 100644 --- a/sarathi/engine/pipeline_parallel_llm_engine.py +++ b/sarathi/engine/pipeline_parallel_llm_engine.py @@ -1,17 +1,22 @@ import time -from threading import Thread, Event -from typing import List -from queue import Queue from dataclasses import dataclass +from queue import Queue +from threading import Event, Thread +from typing import List -from sarathi.config import (CacheConfig, MetricsConfig, ModelConfig, - ParallelConfig, BaseSchedulerConfig) -from sarathi.logger import init_logger +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + MetricsConfig, + ModelConfig, + ParallelConfig, +) from sarathi.core.datatypes.request_output import RequestOutput from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import SequenceMetadata -from sarathi.utils.threading_utils import exit_on_error from sarathi.engine.base_llm_engine import BaseLLMEngine +from sarathi.logger import init_logger +from sarathi.utils.threading_utils import exit_on_error logger = init_logger(__name__) @@ -56,8 +61,13 @@ def __init__( scheduler_config: BaseSchedulerConfig, metrics_config: MetricsConfig, ) -> None: - 
super().__init__(model_config, cache_config, parallel_config, - scheduler_config, metrics_config) + super().__init__( + model_config, + cache_config, + parallel_config, + scheduler_config, + metrics_config, + ) # Create the request queue. self.has_started_execution_loops = False self.scheduler_output_queue = Queue() @@ -66,10 +76,12 @@ def __init__( self.microbatch_watch_event = Event() self.schedule_thread = Thread(target=self._schedule_loop, daemon=True) self.microbatch_watch_thread = Thread( - target=self._microbatch_watch_loop, daemon=True) + target=self._microbatch_watch_loop, daemon=True + ) self.output_thread = Thread(target=self._output_loop, daemon=True) - self.scheduler_timer_thread = Thread(target=self._scheduler_timer_loop, - daemon=True) + self.scheduler_timer_thread = Thread( + target=self._scheduler_timer_loop, daemon=True + ) def _validate_parallel_config(self) -> None: assert self.parallel_config.pipeline_parallel_size > 1 @@ -93,6 +105,7 @@ def _get_worker_impl(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from sarathi.worker.pipeline_parallel_worker import PipelineParallelWorker + return PipelineParallelWorker @exit_on_error @@ -109,7 +122,8 @@ def _schedule_loop(self) -> None: continue ignored_seqs, seq_metadata_list = self.seq_manager.on_schedule( - scheduler_outputs) + scheduler_outputs + ) self.scheduler_output_queue.put( ScheduleStageOutputs( @@ -117,7 +131,8 @@ def _schedule_loop(self) -> None: seq_metadata_list, scheduler_outputs, start_time, - )) + ) + ) if not scheduler_outputs.is_empty(): self.microbatch_watch_event.set() @@ -128,8 +143,7 @@ def _schedule_loop(self) -> None: ) end_time = time.perf_counter() - self.metrics_store.on_schedule(seq_metadata_list, start_time, - end_time) + self.metrics_store.on_schedule(seq_metadata_list, start_time, end_time) @exit_on_error def _microbatch_watch_loop(self) -> None: @@ -149,8 +163,10 @@ def _output_loop(self) -> None: scheduler_stage_output = self.scheduler_output_queue.get() sampler_outputs = self._run_worker( - (0, self.parallel_config.pipeline_parallel_size - - 1), # TP rank zero for last pipeline stage + ( + 0, + self.parallel_config.pipeline_parallel_size - 1, + ), # TP rank zero for last pipeline stage "get_output", ) @@ -164,8 +180,10 @@ def _output_loop(self) -> None: all_request_outputs = self._on_step_completed( scheduler_stage_output.scheduler_outputs, scheduler_stage_output.ignored_seqs, - scheduler_stage_output.seq_metadata_list, sampler_outputs, - scheduler_stage_output.start_time) + scheduler_stage_output.seq_metadata_list, + sampler_outputs, + scheduler_stage_output.start_time, + ) self.schedule_event.set() self.output_queue.put(all_request_outputs) diff --git a/sarathi/engine/ray_utils.py b/sarathi/engine/ray_utils.py index 6adad47..9b3f472 100644 --- a/sarathi/engine/ray_utils.py +++ b/sarathi/engine/ray_utils.py @@ -8,7 +8,7 @@ try: import ray - class RayWorker(): + class RayWorker: """Ray wrapper for sarathi.worker.Worker, allowing Worker to be lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" @@ -16,6 +16,7 @@ def __init__(self, init_cached_hf_modules=False) -> None: if init_cached_hf_modules: # pylint: disable=import-outside-toplevel from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() unset_cuda_visible_devices() self.worker = None @@ -31,14 +32,18 @@ def execute_method(self, method, *args, **kwargs): return executor(*args, **kwargs) except ImportError as e: - 
logger.warning(f"Failed to import Ray with {e!r}. " - "For distributed inference, please install Ray with " - "`pip install ray pandas pyarrow`.") + logger.warning( + f"Failed to import Ray with {e!r}. " + "For distributed inference, please install Ray with " + "`pip install ray pandas pyarrow`." + ) ray = None RayWorker = None # pylint: disable=invalid-name -def initialize_cluster(ray_address: Optional[str] = None, ): +def initialize_cluster( + ray_address: Optional[str] = None, +): """Initialize the distributed cluster probably with Ray. Args: @@ -47,7 +52,7 @@ def initialize_cluster(ray_address: Optional[str] = None, ): """ if ray is None: raise ImportError( - "Ray is not installed. Please install Ray to use distributed " - "serving.") + "Ray is not installed. Please install Ray to use distributed " "serving." + ) # Connect to a ray cluster. ray.init(address=ray_address, ignore_reinit_error=True) diff --git a/sarathi/metrics/cdf_sketch.py b/sarathi/metrics/cdf_sketch.py index e8671b6..1635d97 100644 --- a/sarathi/metrics/cdf_sketch.py +++ b/sarathi/metrics/cdf_sketch.py @@ -3,9 +3,7 @@ import numpy as np import pandas as pd import plotly_express as px - import wandb - from ddsketch.ddsketch import DDSketch logger = logging.getLogger(__name__) @@ -82,32 +80,32 @@ def print_distribution_stats(self, plot_name: str) -> None: f" 99th percentile: {self.sketch.get_quantile_value(0.99)}" f" 99.9th percentile: {self.sketch.get_quantile_value(0.999)}" f" count: {self.sketch._count}" - f" sum: {self.sketch.sum}") + f" sum: {self.sketch.sum}" + ) if wandb.run: wandb.log( { - f"{plot_name}_min": - self.sketch._min, - f"{plot_name}_max": - self.sketch._max, - f"{plot_name}_mean": - self.sketch.avg, - f"{plot_name}_25th_percentile": - self.sketch.get_quantile_value(0.25), - f"{plot_name}_median": - self.sketch.get_quantile_value(0.5), - f"{plot_name}_75th_percentile": - self.sketch.get_quantile_value(0.75), - f"{plot_name}_95th_percentile": - self.sketch.get_quantile_value(0.95), - f"{plot_name}_99th_percentile": - self.sketch.get_quantile_value(0.99), - f"{plot_name}_99.9th_percentile": - self.sketch.get_quantile_value(0.999), - f"{plot_name}_count": - self.sketch.count, - f"{plot_name}_sum": - self.sketch.sum, + f"{plot_name}_min": self.sketch._min, + f"{plot_name}_max": self.sketch._max, + f"{plot_name}_mean": self.sketch.avg, + f"{plot_name}_25th_percentile": self.sketch.get_quantile_value( + 0.25 + ), + f"{plot_name}_median": self.sketch.get_quantile_value(0.5), + f"{plot_name}_75th_percentile": self.sketch.get_quantile_value( + 0.75 + ), + f"{plot_name}_95th_percentile": self.sketch.get_quantile_value( + 0.95 + ), + f"{plot_name}_99th_percentile": self.sketch.get_quantile_value( + 0.99 + ), + f"{plot_name}_99.9th_percentile": self.sketch.get_quantile_value( + 0.999 + ), + f"{plot_name}_count": self.sketch.count, + f"{plot_name}_sum": self.sketch.sum, }, step=0, ) @@ -116,24 +114,16 @@ def to_df(self) -> pd.DataFrame: # get quantiles at 1% intervals quantiles = np.linspace(0, 1, self._num_quantiles_in_df) # get quantile values - quantile_values = [ - self.sketch.get_quantile_value(q) for q in quantiles - ] + quantile_values = [self.sketch.get_quantile_value(q) for q in quantiles] # create dataframe - df = pd.DataFrame({ - "cdf": quantiles, - self.metric_name: quantile_values - }) + df = pd.DataFrame({"cdf": quantiles, self.metric_name: quantile_values}) return df def _save_df(self, df: pd.DataFrame, path: str, plot_name: str) -> None: df.to_csv(f"{path}/{plot_name}.csv", index=False) - def 
plot_cdf(self, - path: str, - plot_name: str, - x_axis_label: str = None) -> None: + def plot_cdf(self, path: str, plot_name: str, x_axis_label: str = None) -> None: if self.sketch._count == 0: return @@ -145,23 +135,19 @@ def plot_cdf(self, self.print_distribution_stats(plot_name) - fig = px.line(df, - x=self.metric_name, - y="cdf", - markers=True, - labels={"x": x_axis_label}) + fig = px.line( + df, x=self.metric_name, y="cdf", markers=True, labels={"x": x_axis_label} + ) fig.update_traces(marker=dict(color="red", size=2)) if wandb.run: wandb_df = df.copy() # rename the self.metric_name column to x_axis_label - wandb_df = wandb_df.rename( - columns={self.metric_name: x_axis_label}) + wandb_df = wandb_df.rename(columns={self.metric_name: x_axis_label}) wandb.log( { - f"{plot_name}_cdf": - wandb.plot.line( + f"{plot_name}_cdf": wandb.plot.line( wandb.Table(dataframe=wandb_df), "cdf", x_axis_label, diff --git a/sarathi/metrics/constants.py b/sarathi/metrics/constants.py index eb08ff3..d8f407e 100644 --- a/sarathi/metrics/constants.py +++ b/sarathi/metrics/constants.py @@ -49,24 +49,31 @@ class SequenceMetricsTimeDistributions(enum.Enum): REQUEST_PREEMPTION_TIME = "request_preemption_time" REQUEST_SCHEDULING_DELAY = "request_scheduling_delay" REQUEST_EXECUTION_PLUS_PREEMPTION_TIME = "request_execution_plus_preemption_time" - REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED = "request_execution_plus_preemption_time_normalized" + REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED = ( + "request_execution_plus_preemption_time_normalized" + ) PREFILL_TIME_E2E = "prefill_e2e_time" PREFILL_TIME_E2E_NORMALIZED = "prefill_e2e_time_normalized" - PREFILL_TIME_E2E_PIECEWISE_NORMALIZED = ( - "prefill_e2e_time_piecewise_normalized") + PREFILL_TIME_E2E_PIECEWISE_NORMALIZED = "prefill_e2e_time_piecewise_normalized" PREFILL_TIME_EXECUTION_PLUS_PREEMPTION = "prefill_time_execution_plus_preemption" PREFILL_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED = ( - "prefill_time_execution_plus_preemption_normalized") + "prefill_time_execution_plus_preemption_normalized" + ) DECODE_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED = ( - "decode_time_execution_plus_preemption_normalized") + "decode_time_execution_plus_preemption_normalized" + ) class TokenMetricsTimeDistribution(enum.Enum): - DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME = "decode_token_execution_plus_preemption_time" + DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME = ( + "decode_token_execution_plus_preemption_time" + ) class TokenMetricsTimeList(enum.Enum): - DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME_LIST = "decode_token_execution_plus_preemption_time_list" + DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME_LIST = ( + "decode_token_execution_plus_preemption_time_list" + ) class SequenceMetricsHistogram(enum.Enum): diff --git a/sarathi/metrics/cpu_timer.py b/sarathi/metrics/cpu_timer.py index d7bb788..98ff36b 100644 --- a/sarathi/metrics/cpu_timer.py +++ b/sarathi/metrics/cpu_timer.py @@ -1,8 +1,8 @@ -import torch - from time import perf_counter from typing import Optional +import torch + from sarathi.metrics.constants import CpuOperationMetrics from sarathi.metrics.metrics_store import MetricsStore @@ -14,7 +14,8 @@ def __init__(self, name: CpuOperationMetrics, rank: Optional[int] = None): self.start_time = None self.metrics_store = MetricsStore() self.disabled = not self.metrics_store.is_op_enabled( - metric_name=self.name, rank=rank) + metric_name=self.name, rank=rank + ) def __enter__(self): if self.disabled: @@ -29,6 +30,5 @@ def __exit__(self, *_): 
torch.cuda.synchronize() self.metrics_store.push_cpu_operation_metrics( - self.name, - (perf_counter() - self.start_time) * 1e3 # convert to ms + self.name, (perf_counter() - self.start_time) * 1e3 # convert to ms ) diff --git a/sarathi/metrics/cuda_timer.py b/sarathi/metrics/cuda_timer.py index 5b8466f..4fb5a20 100644 --- a/sarathi/metrics/cuda_timer.py +++ b/sarathi/metrics/cuda_timer.py @@ -1,21 +1,25 @@ +from typing import Optional + import torch -from typing import Optional from sarathi.metrics.constants import OperationMetrics from sarathi.metrics.metrics_store import MetricsStore class CudaTimer: - def __init__(self, - name: OperationMetrics, - layer_id: Optional[int] = None, - rank: Optional[int] = None): + def __init__( + self, + name: OperationMetrics, + layer_id: Optional[int] = None, + rank: Optional[int] = None, + ): self.name = name self.metrics_store = MetricsStore() self.layer_id = layer_id self.disabled = (name is None) or not self.metrics_store.is_op_enabled( - metric_name=self.name, layer_id=layer_id, rank=rank) + metric_name=self.name, layer_id=layer_id, rank=rank + ) if self.disabled: return @@ -42,8 +46,7 @@ def __enter__(self): return self def handle_trace(self, trace): - total_cuda_time = sum( - [e.cuda_time_total for e in trace.key_averages()]) + total_cuda_time = sum([e.cuda_time_total for e in trace.key_averages()]) self.metrics_store.push_operation_metrics( self.name, @@ -58,6 +61,7 @@ def __exit__(self, *args): self.end_event = torch.cuda.Event(enable_timing=True) self.end_event.record() self.metrics_store.push_operation_metrics_events( - self.name, self.start_event, self.end_event) + self.name, self.start_event, self.end_event + ) else: self.profiler.__exit__(*args) diff --git a/sarathi/metrics/data_series.py b/sarathi/metrics/data_series.py index c8b38c3..c068007 100644 --- a/sarathi/metrics/data_series.py +++ b/sarathi/metrics/data_series.py @@ -1,10 +1,10 @@ import logging +from collections import defaultdict, deque + import pandas as pd import plotly_express as px import wandb -from collections import defaultdict, deque - logger = logging.getLogger(__name__) @@ -25,7 +25,9 @@ def __init__( # to aid incremental updates to y datapoints self._last_data_y = 0 - def consolidate(self, ): + def consolidate( + self, + ): res = defaultdict(list) for x, y in self.data_series: res[x].append(y) @@ -33,8 +35,7 @@ def consolidate(self, ): # sort by x self.data_series = sorted(self.data_series, key=lambda x: x[0]) - self._last_data_y = self.data_series[-1][1] if len( - self.data_series) else 0 + self._last_data_y = self.data_series[-1][1] if len(self.data_series) else 0 def merge(self, other: "DataSeries"): if len(other) == 0: @@ -101,8 +102,7 @@ def _peek_y(self): # convert list of x, y datapoints to a pandas dataframe def to_df(self): - return pd.DataFrame(self.data_series, - columns=[self.x_name, self.y_name]) + return pd.DataFrame(self.data_series, columns=[self.x_name, self.y_name]) # add a new x, y datapoint as an incremental (delta) update to # recently collected y datapoint @@ -111,10 +111,9 @@ def put_delta(self, data_x: float, data_y_delta: float) -> None: data_y = last_data_y + data_y_delta self.put(data_x, data_y) - def print_series_stats(self, - df: pd.DataFrame, - plot_name: str, - y_name: str = None) -> None: + def print_series_stats( + self, df: pd.DataFrame, plot_name: str, y_name: str = None + ) -> None: if len(self.data_series) == 0: return @@ -122,10 +121,12 @@ def print_series_stats(self, if y_name is None: y_name = self.y_name - 
logger.info(f"{plot_name}: {y_name} stats:" - f" min: {df[y_name].min()}," - f" max: {df[y_name].max()}," - f" mean: {df[y_name].mean()},") + logger.info( + f"{plot_name}: {y_name} stats:" + f" min: {df[y_name].min()}," + f" max: {df[y_name].max()}," + f" mean: {df[y_name].mean()}," + ) if wandb.run: wandb.log( { @@ -136,10 +137,9 @@ def print_series_stats(self, step=0, ) - def print_distribution_stats(self, - df: pd.DataFrame, - plot_name: str, - y_name: str = None) -> None: + def print_distribution_stats( + self, df: pd.DataFrame, plot_name: str, y_name: str = None + ) -> None: if len(self.data_series) == 0: return @@ -147,14 +147,16 @@ def print_distribution_stats(self, if y_name is None: y_name = self.y_name - logger.info(f"{plot_name}: {y_name} stats:" - f" min: {df[y_name].min()}," - f" max: {df[y_name].max()}," - f" mean: {df[y_name].mean()}," - f" median: {df[y_name].median()}," - f" 95th percentile: {df[y_name].quantile(0.95)}," - f" 99th percentile: {df[y_name].quantile(0.99)}" - f" 99.9th percentile: {df[y_name].quantile(0.999)}") + logger.info( + f"{plot_name}: {y_name} stats:" + f" min: {df[y_name].min()}," + f" max: {df[y_name].max()}," + f" mean: {df[y_name].mean()}," + f" median: {df[y_name].median()}," + f" 95th percentile: {df[y_name].quantile(0.95)}," + f" 99th percentile: {df[y_name].quantile(0.99)}" + f" 99.9th percentile: {df[y_name].quantile(0.999)}" + ) if wandb.run: wandb.log( { @@ -164,8 +166,7 @@ def print_distribution_stats(self, f"{plot_name}_median": df[y_name].median(), f"{plot_name}_95th_percentile": df[y_name].quantile(0.95), f"{plot_name}_99th_percentile": df[y_name].quantile(0.99), - f"{plot_name}_99.9th_percentile": - df[y_name].quantile(0.999), + f"{plot_name}_99.9th_percentile": df[y_name].quantile(0.999), }, step=0, ) @@ -202,11 +203,9 @@ def plot_step( self.print_series_stats(df, plot_name) # change marker color to red - fig = px.line(df, - x=self.x_name, - y=self.y_name, - markers=True, - labels={"x": y_axis_label}) + fig = px.line( + df, x=self.x_name, y=self.y_name, markers=True, labels={"x": y_axis_label} + ) fig.update_traces(marker=dict(color="red", size=2)) if wandb.run: @@ -216,8 +215,7 @@ def plot_step( wandb.log( { - f"{plot_name}_step": - wandb.plot.line( + f"{plot_name}_step": wandb.plot.line( wandb.Table(dataframe=wandb_df), self.x_name, y_axis_label, @@ -230,10 +228,7 @@ def plot_step( fig.write_image(f"{path}/{plot_name}.png") self._save_df(df, path, plot_name) - def plot_cdf(self, - path: str, - plot_name: str, - y_axis_label: str = None) -> None: + def plot_cdf(self, path: str, plot_name: str, y_axis_label: str = None) -> None: if len(self.data_series) == 0: return @@ -249,11 +244,9 @@ def plot_cdf(self, # sort by cdf df = df.sort_values(by=["cdf"]) - fig = px.line(df, - x=self.y_name, - y="cdf", - markers=True, - labels={"x": y_axis_label}) + fig = px.line( + df, x=self.y_name, y="cdf", markers=True, labels={"x": y_axis_label} + ) fig.update_traces(marker=dict(color="red", size=2)) if wandb.run: @@ -263,8 +256,7 @@ def plot_cdf(self, wandb.log( { - f"{plot_name}_cdf": - wandb.plot.line( + f"{plot_name}_cdf": wandb.plot.line( wandb.Table(dataframe=wandb_df), "cdf", y_axis_label, @@ -290,8 +282,7 @@ def plot_histogram(self, path: str, plot_name: str) -> None: # wandb histogram is highly inaccurate so we need to generate the histogram # ourselves and then use wandb bar chart - histogram_df = df[self.y_name].value_counts(bins=25, - sort=False).sort_index() + histogram_df = df[self.y_name].value_counts(bins=25, sort=False).sort_index() 
histogram_df = histogram_df.reset_index() histogram_df.columns = ["Bins", "count"] histogram_df["Bins"] = histogram_df["Bins"].apply(lambda x: x.mid) @@ -304,8 +295,7 @@ def plot_histogram(self, path: str, plot_name: str) -> None: if wandb.run: wandb.log( { - f"{plot_name}_histogram": - wandb.plot.bar( + f"{plot_name}_histogram": wandb.plot.bar( wandb.Table(dataframe=histogram_df), "Bins", "Percentage", # wandb plots are horizontal diff --git a/sarathi/metrics/metrics_store.py b/sarathi/metrics/metrics_store.py index 423cdba..e18ec56 100644 --- a/sarathi/metrics/metrics_store.py +++ b/sarathi/metrics/metrics_store.py @@ -1,34 +1,34 @@ -from dataclasses import asdict -import os import json +import logging +import os +import zipfile from copy import deepcopy +from dataclasses import asdict from functools import reduce -from typing import Dict, List, Optional, Tuple, Any, Union -import zipfile +from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd import plotly.express as px import torch import wandb -import logging from sarathi.config import MetricsConfig +from sarathi.core.datatypes.request_output import RequestOutput +from sarathi.core.datatypes.scheduler_output import SchedulerOutputs +from sarathi.core.datatypes.sequence import Sequence, SequenceMetadata +from sarathi.metrics.cdf_sketch import CDFSketch from sarathi.metrics.constants import ( - TokenMetricsTimeDistribution, + BatchMetricsCountDistribution, + BatchMetricsTimeDistribution, + CompletionMetricsTimeSeries, CpuOperationMetrics, OperationMetrics, - SequenceMetricsTimeDistributions, SequenceMetricsHistogram, - CompletionMetricsTimeSeries, - BatchMetricsCountDistribution, - BatchMetricsTimeDistribution, + SequenceMetricsTimeDistributions, + TokenMetricsTimeDistribution, TokenMetricsTimeList, ) from sarathi.metrics.data_series import DataSeries -from sarathi.metrics.cdf_sketch import CDFSketch -from sarathi.core.datatypes.request_output import RequestOutput -from sarathi.core.datatypes.scheduler_output import SchedulerOutputs -from sarathi.core.datatypes.sequence import Sequence, SequenceMetadata from sarathi.utils.singleton import Singleton logger = logging.getLogger(__name__) @@ -88,16 +88,20 @@ def __init__(self, metrics_config: MetricsConfig): self._enable_cpu_op_level_metrics = metrics_config.enable_cpu_op_level_metrics self._enable_chrome_trace = metrics_config.enable_chrome_trace self._enable_request_outputs = metrics_config.enable_request_outputs - self._keep_individual_batch_metrics = metrics_config.keep_individual_batch_metrics + self._keep_individual_batch_metrics = ( + metrics_config.keep_individual_batch_metrics + ) self._model_num_layers = metrics_config.model_num_layers self.reset() self._init_wandb() - def is_op_enabled(self, - metric_name: Any, - rank: Optional[int] = None, - layer_id: Optional[int] = None) -> bool: + def is_op_enabled( + self, + metric_name: Any, + rank: Optional[int] = None, + layer_id: Optional[int] = None, + ) -> bool: if self.disabled: return False @@ -107,15 +111,15 @@ def is_op_enabled(self, if not self._enable_cpu_op_level_metrics: return False if metric_name in [ - CpuOperationMetrics.SCHEDULE, - CpuOperationMetrics.PROCESS_MODEL_OUTPUTS, + CpuOperationMetrics.SCHEDULE, + CpuOperationMetrics.PROCESS_MODEL_OUTPUTS, ]: assert rank is None return True elif metric_name in [ - CpuOperationMetrics.PREPARE_INPUTS_E2E, - CpuOperationMetrics.MODEL_EXECUTION_E2E, - CpuOperationMetrics.SAMPLER_E2E, + CpuOperationMetrics.PREPARE_INPUTS_E2E, + 
CpuOperationMetrics.MODEL_EXECUTION_E2E, + CpuOperationMetrics.SAMPLER_E2E, ]: return rank == 0 raise ValueError(f"Unknown metric name: {metric_name}") @@ -123,7 +127,8 @@ def is_op_enabled(self, def reset(self): # Initialise request metrics self.seq_metrics_time_distributions: Dict[ - SequenceMetricsTimeDistributions, DataSeries] = {} + SequenceMetricsTimeDistributions, DataSeries + ] = {} for metric_name in SequenceMetricsTimeDistributions: self.seq_metrics_time_distributions[metric_name] = DataSeries( REQUEST_ID_STR, @@ -131,7 +136,8 @@ def reset(self): ) self.token_metrics_time_distribution: Dict[ - TokenMetricsTimeDistribution, CDFSketch] = {} + TokenMetricsTimeDistribution, CDFSketch + ] = {} for metric_name in TokenMetricsTimeDistribution: self.token_metrics_time_distribution[metric_name] = CDFSketch( metric_name.value, @@ -139,16 +145,14 @@ def reset(self): num_quantiles_in_df=1001, ) - self.token_metrics_time_list: Dict[TokenMetricsTimeList, - DataSeries] = {} + self.token_metrics_time_list: Dict[TokenMetricsTimeList, DataSeries] = {} for metric_name in TokenMetricsTimeList: self.token_metrics_time_list[metric_name] = DataSeries( DECODE_TOKEN_ID_STR, metric_name.value, ) - self.seq_metrics_histogram: Dict[SequenceMetricsHistogram, - DataSeries] = {} + self.seq_metrics_histogram: Dict[SequenceMetricsHistogram, DataSeries] = {} for metric_name in SequenceMetricsHistogram: self.seq_metrics_histogram[metric_name] = DataSeries( REQUEST_ID_STR, @@ -160,30 +164,43 @@ def reset(self): # Initialise batch metrics self.batch_metrics_count_distribution: Dict[ - BatchMetricsCountDistribution, Union[DataSeries, CDFSketch]] = {} + BatchMetricsCountDistribution, Union[DataSeries, CDFSketch] + ] = {} for metric_name in BatchMetricsCountDistribution: - self.batch_metrics_count_distribution[metric_name] = DataSeries( - BATCH_ID_STR, - metric_name.value, - ) if self._keep_individual_batch_metrics else CDFSketch( - metric_name.value, ) + self.batch_metrics_count_distribution[metric_name] = ( + DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + if self._keep_individual_batch_metrics + else CDFSketch( + metric_name.value, + ) + ) self.batch_metrics_time_distribution: Dict[ - BatchMetricsTimeDistribution, Union[DataSeries, CDFSketch]] = {} + BatchMetricsTimeDistribution, Union[DataSeries, CDFSketch] + ] = {} for metric_name in BatchMetricsTimeDistribution: - self.batch_metrics_time_distribution[metric_name] = DataSeries( - BATCH_ID_STR, - metric_name.value, - ) if self._keep_individual_batch_metrics else CDFSketch( - metric_name.value, ) + self.batch_metrics_time_distribution[metric_name] = ( + DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + if self._keep_individual_batch_metrics + else CDFSketch( + metric_name.value, + ) + ) # to measure the time wasted between the last batch and the next batch self._last_batch_end_time = None self._next_batch_id = 0 # Initialise completion metrics - self.completion_metrics_time_series: Dict[CompletionMetricsTimeSeries, - DataSeries] = {} + self.completion_metrics_time_series: Dict[ + CompletionMetricsTimeSeries, DataSeries + ] = {} for metric_name in CompletionMetricsTimeSeries: self.completion_metrics_time_series[metric_name] = DataSeries( TIME_STR, @@ -191,34 +208,44 @@ def reset(self): ) self.operation_metrics: Dict[OperationMetrics, CDFSketch] = {} - self.operation_metrics_per_batch: Dict[OperationMetrics, - DataSeries] = {} + self.operation_metrics_per_batch: Dict[OperationMetrics, DataSeries] = {} self.operation_metrics_per_batch_events: Dict[ - 
OperationMetrics, List[Tuple[torch.cuda.Event]]] = {} + OperationMetrics, List[Tuple[torch.cuda.Event]] + ] = {} for metric_name in OperationMetrics: self.operation_metrics[metric_name] = CDFSketch( - metric_name.value, ) + metric_name.value, + ) self.operation_metrics_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, ) self.operation_metrics_per_batch_events[metric_name] = [] - self.cpu_operation_metrics: Dict[CpuOperationMetrics, - Union[CDFSketch, DataSeries]] = {} + self.cpu_operation_metrics: Dict[ + CpuOperationMetrics, Union[CDFSketch, DataSeries] + ] = {} for metric_name in CpuOperationMetrics: - self.cpu_operation_metrics[metric_name] = DataSeries( - BATCH_ID_STR, - metric_name.value, - ) if self._keep_individual_batch_metrics else CDFSketch( - metric_name.value, ) + self.cpu_operation_metrics[metric_name] = ( + DataSeries( + BATCH_ID_STR, + metric_name.value, + ) + if self._keep_individual_batch_metrics + else CDFSketch( + metric_name.value, + ) + ) self.chrome_trace: List[Dict[str, Any]] = [] self.requests_outputs: List[RequestOutput] = [] def _init_wandb(self): - if (not self.should_write_metrics or not self._wandb_project - or not self._wandb_group): + if ( + not self.should_write_metrics + or not self._wandb_project + or not self._wandb_group + ): return logger.info( @@ -226,8 +253,7 @@ def _init_wandb(self): f", sweep_id: {self._wandb_sweep_id}, run_id: {self._wandb_run_id}" ) if self._wandb_sweep_id or self._wandb_run_id: - logger.warn( - "wandb_sweep_id and wandb_run_id are not supported yet.") + logger.warn("wandb_sweep_id and wandb_run_id are not supported yet.") wandb.init( project=self._wandb_project, @@ -254,14 +280,15 @@ def _get_seq_id(self, seq_id: str) -> str: @if_write_metrics def on_request_arrival(self, seq: Sequence) -> None: self.completion_metrics_time_series[ - CompletionMetricsTimeSeries.REQUEST_ARRIVAL].put( - seq.state.arrived_at, 1) + CompletionMetricsTimeSeries.REQUEST_ARRIVAL + ].put(seq.state.arrived_at, 1) if self._last_request_arrived_at is not None: self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_INTER_ARRIVAL_DELAY].put( - self._get_seq_id(seq.seq_id), - seq.state.arrived_at - self._last_request_arrived_at, - ) + SequenceMetricsHistogram.REQUEST_INTER_ARRIVAL_DELAY + ].put( + self._get_seq_id(seq.seq_id), + seq.state.arrived_at - self._last_request_arrived_at, + ) self._last_request_arrived_at = seq.state.arrived_at @if_write_metrics @@ -271,12 +298,11 @@ def _on_request_end(self, seq: Sequence) -> None: # log request outputs and completion metrics regardless of whether the request is ignored or not self.completion_metrics_time_series[ - CompletionMetricsTimeSeries.REQUEST_COMPLETION].put( - seq.state.completed_at, 1) - self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_NUM_IGNORED].put( - self._get_seq_id(seq.seq_id), - int(seq.state.is_ignore_finished)) + CompletionMetricsTimeSeries.REQUEST_COMPLETION + ].put(seq.state.completed_at, 1) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_IGNORED].put( + self._get_seq_id(seq.seq_id), int(seq.state.is_ignore_finished) + ) if seq.state.is_ignore_finished: # do not log metrics for ignored requests, they can skew the results @@ -286,95 +312,93 @@ def _on_request_end(self, seq: Sequence) -> None: self.requests_outputs.append(RequestOutput.from_seq(seq)) # first log all the histograms - self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_NUM_TOKENS].put( - self._get_seq_id(seq.seq_id), seq.state.num_total_tokens) - 
self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_PREFILL_TOKENS].put( - self._get_seq_id(seq.seq_id), seq.state.num_prompt_tokens) - self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_DECODE_TOKENS].put( - self._get_seq_id(seq.seq_id), seq.state.num_output_tokens) - self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_PD_RATIO].put( - self._get_seq_id(seq.seq_id), - seq.state.num_prompt_tokens / seq.state.num_output_tokens) - self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_NUM_RESTARTS].put( - self._get_seq_id(seq.seq_id), seq.state.num_restarts) - self.seq_metrics_histogram[ - SequenceMetricsHistogram.REQUEST_NUM_PAUSES].put( - self._get_seq_id(seq.seq_id), seq.state.num_pauses) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_TOKENS].put( + self._get_seq_id(seq.seq_id), seq.state.num_total_tokens + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_PREFILL_TOKENS].put( + self._get_seq_id(seq.seq_id), seq.state.num_prompt_tokens + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_DECODE_TOKENS].put( + self._get_seq_id(seq.seq_id), seq.state.num_output_tokens + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_PD_RATIO].put( + self._get_seq_id(seq.seq_id), + seq.state.num_prompt_tokens / seq.state.num_output_tokens, + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_RESTARTS].put( + self._get_seq_id(seq.seq_id), seq.state.num_restarts + ) + self.seq_metrics_histogram[SequenceMetricsHistogram.REQUEST_NUM_PAUSES].put( + self._get_seq_id(seq.seq_id), seq.state.num_pauses + ) # then log all the time distributions self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.REQUEST_E2E_TIME].put( - self._get_seq_id(seq.seq_id), seq.state.e2e_time) + SequenceMetricsTimeDistributions.REQUEST_E2E_TIME + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_time) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.REQUEST_E2E_TIME_NORMALIZED].put( - self._get_seq_id(seq.seq_id), seq.state.e2e_time_normalized) + SequenceMetricsTimeDistributions.REQUEST_E2E_TIME_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_time_normalized) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. - REQUEST_E2E_TIME_PIECEWISE_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.e2e_time_piecewise_normalized) + SequenceMetricsTimeDistributions.REQUEST_E2E_TIME_PIECEWISE_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_time_piecewise_normalized) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. - REQUEST_EXECUTION_PLUS_PREEMPTION_TIME].put( - self._get_seq_id(seq.seq_id), - seq.state.execution_plus_preemption_time, - ) + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_PLUS_PREEMPTION_TIME + ].put( + self._get_seq_id(seq.seq_id), + seq.state.execution_plus_preemption_time, + ) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. 
- REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.execution_plus_preemption_time_normalized, - ) + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.execution_plus_preemption_time_normalized, + ) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.REQUEST_SCHEDULING_DELAY].put( - self._get_seq_id(seq.seq_id), - seq.state.scheduling_delay, - ) + SequenceMetricsTimeDistributions.REQUEST_SCHEDULING_DELAY + ].put( + self._get_seq_id(seq.seq_id), + seq.state.scheduling_delay, + ) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.REQUEST_EXECUTION_TIME].put( - self._get_seq_id(seq.seq_id), seq.state.execution_time) + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_TIME + ].put(self._get_seq_id(seq.seq_id), seq.state.execution_time) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. - REQUEST_EXECUTION_TIME_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.execution_time_normalized) + SequenceMetricsTimeDistributions.REQUEST_EXECUTION_TIME_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.execution_time_normalized) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.REQUEST_PREEMPTION_TIME].put( - self._get_seq_id(seq.seq_id), seq.state.preempted_time) + SequenceMetricsTimeDistributions.REQUEST_PREEMPTION_TIME + ].put(self._get_seq_id(seq.seq_id), seq.state.preempted_time) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.PREFILL_TIME_E2E].put( - self._get_seq_id(seq.seq_id), seq.state.e2e_prefill_time) + SequenceMetricsTimeDistributions.PREFILL_TIME_E2E + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_prefill_time) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions.PREFILL_TIME_E2E_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.e2e_prefill_time_normalized) + SequenceMetricsTimeDistributions.PREFILL_TIME_E2E_NORMALIZED + ].put(self._get_seq_id(seq.seq_id), seq.state.e2e_prefill_time_normalized) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. - PREFILL_TIME_E2E_PIECEWISE_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.e2e_prefill_time_piecewise_normalized) + SequenceMetricsTimeDistributions.PREFILL_TIME_E2E_PIECEWISE_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.e2e_prefill_time_piecewise_normalized, + ) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. - PREFILL_TIME_EXECUTION_PLUS_PREEMPTION].put( - self._get_seq_id(seq.seq_id), - seq.state.prefill_execution_plus_preemption_time) + SequenceMetricsTimeDistributions.PREFILL_TIME_EXECUTION_PLUS_PREEMPTION + ].put( + self._get_seq_id(seq.seq_id), + seq.state.prefill_execution_plus_preemption_time, + ) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. - PREFILL_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.prefill_execution_plus_preemption_time_normalized, - ) + SequenceMetricsTimeDistributions.PREFILL_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.prefill_execution_plus_preemption_time_normalized, + ) self.seq_metrics_time_distributions[ - SequenceMetricsTimeDistributions. 
- DECODE_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED].put( - self._get_seq_id(seq.seq_id), - seq.state.decode_execution_plus_preemption_time_normalized, - ) + SequenceMetricsTimeDistributions.DECODE_TIME_EXECUTION_PLUS_PREEMPTION_NORMALIZED + ].put( + self._get_seq_id(seq.seq_id), + seq.state.decode_execution_plus_preemption_time_normalized, + ) def _update_per_token_execution_times( self, @@ -388,31 +412,37 @@ def _update_per_token_execution_times( # if prefill has just finished in this iteration, update the prefill completion timeseries if seq.get_output_len() == 1: self.completion_metrics_time_series[ - CompletionMetricsTimeSeries.PREFILL_COMPLETIONS].put( - batch_end_time, - seq.state.num_prompt_tokens, - ) + CompletionMetricsTimeSeries.PREFILL_COMPLETIONS + ].put( + batch_end_time, + seq.state.num_prompt_tokens, + ) self.token_metrics_time_distribution[ - TokenMetricsTimeDistribution. - DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME].put( - seq.state.last_token_generation_time, ) + TokenMetricsTimeDistribution.DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME + ].put( + seq.state.last_token_generation_time, + ) if self._keep_individual_batch_metrics: self.completion_metrics_time_series[ - CompletionMetricsTimeSeries.DECODE_COMPLETIONS].put( - batch_end_time, 1) + CompletionMetricsTimeSeries.DECODE_COMPLETIONS + ].put(batch_end_time, 1) self.token_metrics_time_list[ - TokenMetricsTimeList. - DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME_LIST].put( - f"{self._get_seq_id(seq.seq_id)}_{seq.state.num_output_tokens - 1}", - seq.state.last_token_generation_time, - ) + TokenMetricsTimeList.DECODE_TOKEN_EXECUTION_PLUS_PREEMPTION_TIME_LIST + ].put( + f"{self._get_seq_id(seq.seq_id)}_{seq.state.num_output_tokens - 1}", + seq.state.last_token_generation_time, + ) @check_enabled @if_write_metrics - def on_schedule(self, seq_metadata_list: List[SequenceMetadata], - start_time: float, end_time: float) -> None: + def on_schedule( + self, + seq_metadata_list: List[SequenceMetadata], + start_time: float, + end_time: float, + ) -> None: if not self._enable_chrome_trace: return @@ -429,11 +459,15 @@ def on_schedule(self, seq_metadata_list: List[SequenceMetadata], @check_enabled @if_write_metrics - def on_batch_stage_end(self, seq_metadata_list: List[SequenceMetadata], - scheduler_outputs: SchedulerOutputs, - tensor_parallel_rank: int, - pipeline_parallel_rank: int, start_time: float, - end_time: float) -> None: + def on_batch_stage_end( + self, + seq_metadata_list: List[SequenceMetadata], + scheduler_outputs: SchedulerOutputs, + tensor_parallel_rank: int, + pipeline_parallel_rank: int, + start_time: float, + end_time: float, + ) -> None: self._process_individual_batch_metrics() self._next_batch_id = scheduler_outputs.id + 1 if not self._enable_chrome_trace or len(seq_metadata_list) == 0: @@ -464,62 +498,66 @@ def on_batch_end( execution_time = batch_end_time - batch_start_time for seq_metadata in seq_metadata_list: - self._update_per_token_execution_times(batch_end_time, - seq_metadata.seq) + self._update_per_token_execution_times(batch_end_time, seq_metadata.seq) if seq_metadata.seq.is_finished(): self._on_request_end(seq_metadata.seq) if self._last_batch_end_time is not None: self.batch_metrics_time_distribution[ - BatchMetricsTimeDistribution.INTER_BATCH_DELAY].put_pair( - scheduler_outputs.id, - batch_start_time - self._last_batch_end_time, - ) + BatchMetricsTimeDistribution.INTER_BATCH_DELAY + ].put_pair( + scheduler_outputs.id, + batch_start_time - self._last_batch_end_time, + ) self._last_batch_end_time = 
batch_end_time self.batch_metrics_count_distribution[ - BatchMetricsCountDistribution.BATCH_NUM_TOKENS].put_pair( - scheduler_outputs.id, - scheduler_outputs.num_batched_prompt_tokens + - scheduler_outputs.num_batched_output_tokens, - ) + BatchMetricsCountDistribution.BATCH_NUM_TOKENS + ].put_pair( + scheduler_outputs.id, + scheduler_outputs.num_batched_prompt_tokens + + scheduler_outputs.num_batched_output_tokens, + ) self.batch_metrics_count_distribution[ - BatchMetricsCountDistribution.BATCH_NUM_PREFILL_TOKENS].put_pair( - scheduler_outputs.id, - scheduler_outputs.num_batched_prompt_tokens) + BatchMetricsCountDistribution.BATCH_NUM_PREFILL_TOKENS + ].put_pair(scheduler_outputs.id, scheduler_outputs.num_batched_prompt_tokens) self.batch_metrics_count_distribution[ - BatchMetricsCountDistribution.BATCH_NUM_DECODE_TOKENS].put_pair( - scheduler_outputs.id, - scheduler_outputs.num_batched_output_tokens) + BatchMetricsCountDistribution.BATCH_NUM_DECODE_TOKENS + ].put_pair(scheduler_outputs.id, scheduler_outputs.num_batched_output_tokens) self.batch_metrics_count_distribution[ - BatchMetricsCountDistribution.BATCH_SIZE].put_pair( - scheduler_outputs.id, len(seq_metadata_list)) + BatchMetricsCountDistribution.BATCH_SIZE + ].put_pair(scheduler_outputs.id, len(seq_metadata_list)) # add the only time distribution we have for batch self.batch_metrics_time_distribution[ - BatchMetricsTimeDistribution.BATCH_EXECUTION_TIME].put_pair( - scheduler_outputs.id, execution_time) + BatchMetricsTimeDistribution.BATCH_EXECUTION_TIME + ].put_pair(scheduler_outputs.id, execution_time) - def _to_chrome_trace_dict(self, seq_metadata_list: List[SequenceMetadata], - tensor_parallel_rank: int, - pipeline_parallel_rank: int, start_time: float, - end_time: float) -> Optional[Dict[str, Any]]: + def _to_chrome_trace_dict( + self, + seq_metadata_list: List[SequenceMetadata], + tensor_parallel_rank: int, + pipeline_parallel_rank: int, + start_time: float, + end_time: float, + ) -> Optional[Dict[str, Any]]: if tensor_parallel_rank != 0: return None - seq_ids = [ - seq_metadata.seq.seq_id for seq_metadata in seq_metadata_list - ] + seq_ids = [seq_metadata.seq.seq_id for seq_metadata in seq_metadata_list] prompt_chunk_lens = [ seq_metadata.prompt_chunk_len for seq_metadata in seq_metadata_list ] num_batched_prompt_tokens = sum(prompt_chunk_lens) - num_batched_output_tokens = len([ - seq_metadata for seq_metadata in seq_metadata_list - if not seq_metadata.is_prompt - ]) + num_batched_output_tokens = len( + [ + seq_metadata + for seq_metadata in seq_metadata_list + if not seq_metadata.is_prompt + ] + ) num_batched_tokens = num_batched_prompt_tokens + num_batched_output_tokens @@ -545,8 +583,7 @@ def clear_individual_batch_metrics(self): self.operation_metrics_per_batch_events[metrics_name] = [] def _process_individual_batch_metrics(self): - for metrics_name, events in self.operation_metrics_per_batch_events.items( - ): + for metrics_name, events in self.operation_metrics_per_batch_events.items(): for event in events: start_event, end_event = event time = start_event.elapsed_time(end_event) @@ -565,7 +602,8 @@ def push_operation_metrics_events( return if self._keep_individual_batch_metrics: self.operation_metrics_per_batch_events[metrics_name].append( - [start_event, end_event]) + [start_event, end_event] + ) @check_enabled @if_write_metrics @@ -579,7 +617,8 @@ def push_operation_metrics( self.operation_metrics[metrics_name].put(time) if self._keep_individual_batch_metrics: self.operation_metrics_per_batch[metrics_name].put( - 
self._next_batch_id, time) + self._next_batch_id, time + ) @check_enabled @if_write_metrics @@ -590,8 +629,7 @@ def push_cpu_operation_metrics( ): if not self._enable_cpu_op_level_metrics: return - self.cpu_operation_metrics[metrics_name].put_pair( - self._next_batch_id, time) + self.cpu_operation_metrics[metrics_name].put_pair(self._next_batch_id, time) def _save_as_csv( self, @@ -603,32 +641,39 @@ def _save_as_csv( os.makedirs(base_path, exist_ok=True) dataseries_dfs = [dataseries.to_df() for dataseries in dataseries_list] - assert ([ + assert [ df[key_to_join].is_unique and pd.notnull(df[key_to_join]) for df in dataseries_dfs - ]) + ] merged_df = reduce( lambda left, right: left.merge(right, on=key_to_join, how="outer"), dataseries_dfs, ) merged_df.to_csv(f"{base_path}/{file_name}.csv", index=False) - def _store_bar_plot(self, base_path: str, plot_name: str, x_label: str, - y_label: str, data: Dict[str, float]): - fig = px.bar(x=list(data.keys()), - y=list(data.values()), - labels={ - "x": x_label, - "y": y_label - }) + def _store_bar_plot( + self, + base_path: str, + plot_name: str, + x_label: str, + y_label: str, + data: Dict[str, float], + ): + fig = px.bar( + x=list(data.keys()), + y=list(data.values()), + labels={"x": x_label, "y": y_label}, + ) if wandb.run: wandb.log( { - plot_name: - wandb.plot.bar( - wandb.Table(dataframe=pd.DataFrame( - data=data.items(), columns=[x_label, y_label])), + plot_name: wandb.plot.bar( + wandb.Table( + dataframe=pd.DataFrame( + data=data.items(), columns=[x_label, y_label] + ) + ), x_label, y_label, title=plot_name, @@ -645,9 +690,9 @@ def _store_request_outputs(self): self.requests_outputs.sort(key=lambda x: int(x.request_id)) with open(f"{self._output_dir}/responses.json", "w") as f: - json.dump([asdict(response) for response in self.requests_outputs], - f, - indent="\t") + json.dump( + [asdict(response) for response in self.requests_outputs], f, indent="\t" + ) def _store_operation_metrics(self, base_plot_path: str): if not self._enable_op_level_metrics and not self._enable_cpu_op_level_metrics: @@ -656,23 +701,27 @@ def _store_operation_metrics(self, base_plot_path: str): total_operation_runtimes: Dict[str, float] = {} for dataseries in self.operation_metrics.values(): - dataseries.plot_cdf(base_plot_path, - f"{dataseries.metric_name}_execution_time", - TIME_STR_MS) + dataseries.plot_cdf( + base_plot_path, f"{dataseries.metric_name}_execution_time", TIME_STR_MS + ) # In `is_op_enabled` we take operations from one of the layers and only rank 0 is considered. - total_operation_runtimes[ - dataseries. 
- metric_name] = dataseries.sum * self._model_num_layers + total_operation_runtimes[dataseries.metric_name] = ( + dataseries.sum * self._model_num_layers + ) for dataseries in self.cpu_operation_metrics.values(): - dataseries.plot_cdf(base_plot_path, - f"{dataseries.metric_name}_execution_time", - TIME_STR_MS) + dataseries.plot_cdf( + base_plot_path, f"{dataseries.metric_name}_execution_time", TIME_STR_MS + ) total_operation_runtimes[dataseries.metric_name] = dataseries.sum - self._store_bar_plot(base_plot_path, "total_operation_runtimes", - OPERATION_STR, TIME_STR_MS, - total_operation_runtimes) + self._store_bar_plot( + base_plot_path, + "total_operation_runtimes", + OPERATION_STR, + TIME_STR_MS, + total_operation_runtimes, + ) if not self._keep_individual_batch_metrics: return @@ -685,8 +734,7 @@ def _store_operation_metrics(self, base_plot_path: str): y_axis_label=TIME_STR_MS, y_cumsum=False, ) - operations_dataseries_list = list( - self.operation_metrics_per_batch.values()) + operations_dataseries_list = list(self.operation_metrics_per_batch.values()) self._save_as_csv( dataseries_list=operations_dataseries_list, key_to_join=BATCH_ID_STR, @@ -702,8 +750,7 @@ def _store_operation_metrics(self, base_plot_path: str): y_axis_label=TIME_STR_MS, y_cumsum=False, ) - cpu_operations_dataseries_list = list( - self.cpu_operation_metrics.values()) + cpu_operations_dataseries_list = list(self.cpu_operation_metrics.values()) self._save_as_csv( dataseries_list=cpu_operations_dataseries_list, key_to_join=BATCH_ID_STR, @@ -712,9 +759,9 @@ def _store_operation_metrics(self, base_plot_path: str): ) def _store_seq_metrics(self, base_plot_path: str): - all_seq_metrics = list( - self.seq_metrics_time_distributions.values()) + list( - self.seq_metrics_histogram.values()) + all_seq_metrics = list(self.seq_metrics_time_distributions.values()) + list( + self.seq_metrics_histogram.values() + ) self._save_as_csv( dataseries_list=all_seq_metrics, @@ -732,8 +779,8 @@ def _store_seq_metrics(self, base_plot_path: str): def _store_batch_metrics(self, base_plot_path: str): if self._keep_individual_batch_metrics: all_batch_metrics = list( - self.batch_metrics_count_distribution.values()) + list( - self.batch_metrics_time_distribution.values()) + self.batch_metrics_count_distribution.values() + ) + list(self.batch_metrics_time_distribution.values()) self._save_as_csv( dataseries_list=all_batch_metrics, @@ -743,41 +790,46 @@ def _store_batch_metrics(self, base_plot_path: str): ) for dataseries in self.batch_metrics_time_distribution.values(): - dataseries.plot_cdf(base_plot_path, dataseries.metric_name, - TIME_STR) + dataseries.plot_cdf(base_plot_path, dataseries.metric_name, TIME_STR) if self._keep_individual_batch_metrics: - dataseries.plot_step(base_plot_path, - f"{dataseries.metric_name}_per_batch", - y_axis_label=TIME_STR, - y_cumsum=False) + dataseries.plot_step( + base_plot_path, + f"{dataseries.metric_name}_per_batch", + y_axis_label=TIME_STR, + y_cumsum=False, + ), for dataseries in self.batch_metrics_count_distribution.values(): - dataseries.plot_cdf(base_plot_path, dataseries.metric_name, - COUNT_STR) + dataseries.plot_cdf(base_plot_path, dataseries.metric_name, COUNT_STR) if self._keep_individual_batch_metrics: - dataseries.plot_step(base_plot_path, - f"{dataseries.metric_name}_per_batch", - y_axis_label=COUNT_STR, - y_cumsum=False) + dataseries.plot_step( + base_plot_path, + f"{dataseries.metric_name}_per_batch", + y_axis_label=COUNT_STR, + y_cumsum=False, + ), def _store_completion_metrics(self, 
base_plot_path: str): for dataseries in self.token_metrics_time_distribution.values(): - dataseries.plot_cdf(base_plot_path, dataseries.metric_name, - TIME_STR) + dataseries.plot_cdf(base_plot_path, dataseries.metric_name, TIME_STR) if self._keep_individual_batch_metrics: for dataseries in self.token_metrics_time_list.values(): - dataseries.save_df(path=base_plot_path, - plot_name=dataseries.metric_name) + dataseries.save_df( + path=base_plot_path, plot_name=dataseries.metric_name + ) first_request_arrival_time = self.completion_metrics_time_series[ - CompletionMetricsTimeSeries.REQUEST_ARRIVAL].min_x + CompletionMetricsTimeSeries.REQUEST_ARRIVAL + ].min_x for dataseries in self.completion_metrics_time_series.values(): # subtract the first request arrival time from all the completion times - dataseries.plot_step(base_plot_path, - f"{dataseries.y_name}_time_series", - COUNT_STR, - start_time=first_request_arrival_time) + dataseries.plot_step( + base_plot_path, + f"{dataseries.y_name}_time_series", + COUNT_STR, + start_time=first_request_arrival_time, + ) def _store_chrome_trace(self): if not self._enable_chrome_trace: @@ -789,9 +841,9 @@ def _store_chrome_trace(self): if wandb.run: zip_file_path = f"{self._output_dir}/chrome_trace.zip" - with zipfile.ZipFile(zip_file_path, - "w", - compression=zipfile.ZIP_DEFLATED) as zf: + with zipfile.ZipFile( + zip_file_path, "w", compression=zipfile.ZIP_DEFLATED + ) as zf: zf.writestr( "chrome_trace.json", json.dumps(self.chrome_trace), @@ -815,45 +867,58 @@ def plot(self): def merge(self, other: "MetricsStore"): for metric_name in SequenceMetricsTimeDistributions: self.seq_metrics_time_distributions[metric_name].merge( - other.seq_metrics_time_distributions[metric_name]) + other.seq_metrics_time_distributions[metric_name] + ) for metric_name in TokenMetricsTimeDistribution: self.token_metrics_time_distribution[metric_name].merge( - other.token_metrics_time_distribution[metric_name]) + other.token_metrics_time_distribution[metric_name] + ) if self._keep_individual_batch_metrics: for metric_name in TokenMetricsTimeList: self.token_metrics_time_list[metric_name].merge( - other.token_metrics_time_list[metric_name]) + other.token_metrics_time_list[metric_name] + ) for metric_name in SequenceMetricsHistogram: self.seq_metrics_histogram[metric_name].merge( - other.seq_metrics_histogram[metric_name]) + other.seq_metrics_histogram[metric_name] + ) for metric_name in BatchMetricsCountDistribution: self.batch_metrics_count_distribution[metric_name].merge( - other.batch_metrics_count_distribution[metric_name]) + other.batch_metrics_count_distribution[metric_name] + ) for metric_name in BatchMetricsTimeDistribution: self.batch_metrics_time_distribution[metric_name].merge( - other.batch_metrics_time_distribution[metric_name]) + other.batch_metrics_time_distribution[metric_name] + ) for metric_name in CompletionMetricsTimeSeries: self.completion_metrics_time_series[metric_name].merge( - other.completion_metrics_time_series[metric_name]) + other.completion_metrics_time_series[metric_name] + ) for metric_name in OperationMetrics: - if metric_name in self.operation_metrics and metric_name in other.operation_metrics: + if ( + metric_name in self.operation_metrics + and metric_name in other.operation_metrics + ): self.operation_metrics[metric_name].merge( - other.operation_metrics[metric_name]) + other.operation_metrics[metric_name] + ) for metric_name in OperationMetrics: self.operation_metrics_per_batch[metric_name].elementwise_merge( - 
other.operation_metrics_per_batch[metric_name]) + other.operation_metrics_per_batch[metric_name] + ) for metric_name in CpuOperationMetrics: self.cpu_operation_metrics[metric_name].merge( - other.cpu_operation_metrics[metric_name]) + other.cpu_operation_metrics[metric_name] + ) self.chrome_trace.extend(other.chrome_trace) self.requests_outputs.extend(other.requests_outputs) diff --git a/sarathi/model_executor/attention/__init__.py b/sarathi/model_executor/attention/__init__.py index 04e4c4c..74db59f 100644 --- a/sarathi/model_executor/attention/__init__.py +++ b/sarathi/model_executor/attention/__init__.py @@ -1,10 +1,18 @@ from enum import Enum from typing import Union -from sarathi.model_executor.attention.flashinfer_attention_wrapper import FlashinferAttentionWrapper -from sarathi.model_executor.attention.flashinfer_unpaged_attention_wrapper import FlashinferUnpagedAttentionWrapper -from sarathi.model_executor.attention.flash_attention_wrapper import FlashAttentionWrapper -from sarathi.model_executor.attention.no_op_attention_wrapper import NoOpAttentionWrapper +from sarathi.model_executor.attention.flash_attention_wrapper import ( + FlashAttentionWrapper, +) +from sarathi.model_executor.attention.flashinfer_attention_wrapper import ( + FlashinferAttentionWrapper, +) +from sarathi.model_executor.attention.flashinfer_unpaged_attention_wrapper import ( + FlashinferUnpagedAttentionWrapper, +) +from sarathi.model_executor.attention.no_op_attention_wrapper import ( + NoOpAttentionWrapper, +) class AttentionBackend(Enum): diff --git a/sarathi/model_executor/attention/base_attention_wrapper.py b/sarathi/model_executor/attention/base_attention_wrapper.py index 4c8f475..9159aad 100644 --- a/sarathi/model_executor/attention/base_attention_wrapper.py +++ b/sarathi/model_executor/attention/base_attention_wrapper.py @@ -1,5 +1,5 @@ -from abc import abstractmethod, ABC -from typing import List, Optional, Union, Tuple +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union import torch @@ -33,12 +33,9 @@ def init( So, we have timers for each layer separately. 
""" - def get_timer(self, - operation: OperationMetrics, - layer_id: Optional[int] = None): + def get_timer(self, operation: OperationMetrics, layer_id: Optional[int] = None): if self._timers.get((operation, layer_id)) is None: - self._timers[(operation, - layer_id)] = CudaTimer(operation, layer_id) + self._timers[(operation, layer_id)] = CudaTimer(operation, layer_id) return self._timers.get((operation, layer_id)) @abstractmethod diff --git a/sarathi/model_executor/attention/flash_attention_wrapper.py b/sarathi/model_executor/attention/flash_attention_wrapper.py index 2a5cfcf..e740e4e 100644 --- a/sarathi/model_executor/attention/flash_attention_wrapper.py +++ b/sarathi/model_executor/attention/flash_attention_wrapper.py @@ -1,11 +1,11 @@ -import torch +from typing import List, Optional, Tuple +import torch from vllm_flash_attn import flash_attn_with_kvcache -from typing import List, Optional, Tuple -from sarathi.logger import init_logger from sarathi.config import ModelConfig, ParallelConfig from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.logger import init_logger from sarathi.metrics.constants import OperationMetrics from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper @@ -32,8 +32,9 @@ def init( self.prefill_block_tables: List[torch.Tensor] = None self.decode_block_table: torch.Tensor = None - def get_cache_block(self, num_blocks: int, - **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + def get_cache_block( + self, num_blocks: int, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: k_cache = torch.randn( num_blocks, self.block_size, @@ -76,19 +77,19 @@ def begin_forward( prompt_chunk_len = seq_metadata.prompt_chunk_len current_prompt_chunk_len = seq_metadata.seq.get_next_prompt_chunk_len( - prompt_chunk_len) - processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed( + prompt_chunk_len ) + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() current_total_len = processed_prompt_len + current_prompt_chunk_len prefill_query_lens.append(current_prompt_chunk_len) prefill_cache_lens.append([processed_prompt_len]) - num_blocks_in_use = (current_total_len + self.block_size - - 1) // self.block_size - prefill_block_tables.append( - seq_metadata.block_table[:num_blocks_in_use]) + num_blocks_in_use = ( + current_total_len + self.block_size - 1 + ) // self.block_size + prefill_block_tables.append(seq_metadata.block_table[:num_blocks_in_use]) for seq_metadata in seq_metadata_list: if seq_metadata.is_prompt: @@ -112,8 +113,9 @@ def begin_forward( for cache_lens in prefill_cache_lens ] self.prefill_block_tables = [ - torch.tensor(block_table, dtype=torch.int32, - device=self.device).reshape(1, -1) + torch.tensor(block_table, dtype=torch.int32, device=self.device).reshape( + 1, -1 + ) for block_table in prefill_block_tables ] @@ -121,19 +123,18 @@ def begin_forward( # no decode block table return - self.decode_cache_len = torch.tensor(decode_cache_len, - dtype=torch.int32, - device=self.device) + self.decode_cache_len = torch.tensor( + decode_cache_len, dtype=torch.int32, device=self.device + ) - max_decode_blocks = max( - len(seq_block) for seq_block in decode_block_table) + max_decode_blocks = max(len(seq_block) for seq_block in decode_block_table) decode_block_table_padded = [ seq_block + [0] * (max_decode_blocks - len(seq_block)) for seq_block in decode_block_table ] - self.decode_block_table = torch.tensor(decode_block_table_padded, - dtype=torch.int32, - device=self.device) + 
self.decode_block_table = torch.tensor( + decode_block_table_padded, dtype=torch.int32, device=self.device + ) def end_forward(self): self.is_metadata_initialized = False @@ -165,17 +166,18 @@ def forward( # first process the prefill attention for prefill_cache_len, prefill_block_table, query_len in zip( - self.prefill_cache_lens, self.prefill_block_tables, - self.prefill_query_lens): + self.prefill_cache_lens, self.prefill_block_tables, self.prefill_query_lens + ): with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): - seq_query = query[token_offset:token_offset + - query_len].reshape(1, -1, self.num_q_heads, - self.head_dim) - seq_key = key[token_offset:token_offset + query_len].reshape( - 1, -1, self.num_kv_heads, self.head_dim) - seq_value = value[token_offset:token_offset + - query_len].reshape(1, -1, self.num_kv_heads, - self.head_dim) + seq_query = query[token_offset : token_offset + query_len].reshape( + 1, -1, self.num_q_heads, self.head_dim + ) + seq_key = key[token_offset : token_offset + query_len].reshape( + 1, -1, self.num_kv_heads, self.head_dim + ) + seq_value = value[token_offset : token_offset + query_len].reshape( + 1, -1, self.num_kv_heads, self.head_dim + ) with self.get_timer(OperationMetrics.ATTN_PREFILL, layer_id): seq_output = flash_attn_with_kvcache( @@ -190,10 +192,10 @@ def forward( causal=True, ) - with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, - layer_id): - output[token_offset:token_offset + query_len].copy_( - seq_output.reshape(-1, self.num_q_heads * self.head_dim)) + with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): + output[token_offset : token_offset + query_len].copy_( + seq_output.reshape(-1, self.num_q_heads * self.head_dim) + ) token_offset += query_len @@ -203,15 +205,15 @@ def forward( decode_batch_size = self.decode_cache_len.size(0) with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): - decode_query = query[token_offset:token_offset + - decode_batch_size].reshape( - -1, 1, self.num_q_heads, self.head_dim) - decode_key = key[token_offset:token_offset + - decode_batch_size].reshape( - -1, 1, self.num_kv_heads, self.head_dim) - decode_value = value[token_offset:token_offset + - decode_batch_size].reshape( - -1, 1, self.num_kv_heads, self.head_dim) + decode_query = query[ + token_offset : token_offset + decode_batch_size + ].reshape(-1, 1, self.num_q_heads, self.head_dim) + decode_key = key[token_offset : token_offset + decode_batch_size].reshape( + -1, 1, self.num_kv_heads, self.head_dim + ) + decode_value = value[ + token_offset : token_offset + decode_batch_size + ].reshape(-1, 1, self.num_kv_heads, self.head_dim) with self.get_timer(OperationMetrics.ATTN_DECODE, layer_id): try: @@ -227,17 +229,21 @@ def forward( causal=True, ) except RuntimeError as e: - if "If key is supplied, it must have seqlen <= the seqlen of the KV cache" in str(e): + if ( + "If key is supplied, it must have seqlen <= the seqlen of the KV cache" + in str(e) + ): logger.warning( "Ran into transient error with flash attention: Key length is greater than the cache length. Skipping the attention computation." 
) return output - else: + else: raise e with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): # flatten the seq_output and copy it to the output tensor - output[token_offset:token_offset + decode_batch_size].copy_( - decode_output.reshape(-1, self.num_q_heads * self.head_dim)) + output[token_offset : token_offset + decode_batch_size].copy_( + decode_output.reshape(-1, self.num_q_heads * self.head_dim) + ) return output diff --git a/sarathi/model_executor/attention/flashinfer_attention_wrapper.py b/sarathi/model_executor/attention/flashinfer_attention_wrapper.py index 767e299..ff65066 100644 --- a/sarathi/model_executor/attention/flashinfer_attention_wrapper.py +++ b/sarathi/model_executor/attention/flashinfer_attention_wrapper.py @@ -1,16 +1,14 @@ +from typing import List, Optional + import torch import torch.nn.functional as F -from flashinfer import ( - append_paged_kv_cache, - BatchPrefillWithPagedKVCacheWrapper, -) -from typing import List, Optional +from flashinfer import BatchPrefillWithPagedKVCacheWrapper, append_paged_kv_cache from sarathi.config import ModelConfig, ParallelConfig from sarathi.core.datatypes.sequence import SequenceMetadata from sarathi.metrics.constants import OperationMetrics -from sarathi.model_executor.utils import round_up_to_multiple from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper +from sarathi.model_executor.utils import round_up_to_multiple class FlashinferAttentionWrapper(BaseAttentionWrapper): @@ -25,11 +23,10 @@ def init( ): super().init(model_config, parallel_config, block_size, device) - workspace_buffer = torch.empty(16 * 1024 * 1024, - dtype=torch.uint8, - device=device) - self._wrapper = BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=device + ) + self._wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") self.is_metadata_initialized = False self.is_profiling_iteration = False @@ -81,8 +78,7 @@ def begin_forward( continue prompt_chunk_len = seq_metadata.prompt_chunk_len - processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed( - ) + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() current_total_len = processed_prompt_len + prompt_chunk_len @@ -96,13 +92,14 @@ def begin_forward( # indptr for the prompt tokens in q/o tensor qo_indptr.append(qo_indptr[-1] + prompt_chunk_len) # Compute the kv page indices for the prompt tokens. - num_blocks_in_use = (current_total_len + self.block_size - - 1) // self.block_size - kv_page_indices.extend( - seq_metadata.block_table[:num_blocks_in_use]) + num_blocks_in_use = ( + current_total_len + self.block_size - 1 + ) // self.block_size + kv_page_indices.extend(seq_metadata.block_table[:num_blocks_in_use]) kv_page_indptr.append(kv_page_indptr[-1] + num_blocks_in_use) - kv_last_page_len.append(current_total_len % self.block_size - or self.block_size) + kv_last_page_len.append( + current_total_len % self.block_size or self.block_size + ) for seq_metadata in seq_metadata_list: if seq_metadata.is_prompt: @@ -117,24 +114,20 @@ def begin_forward( qo_indptr.append(qo_indptr[-1] + 1) # Compute the kv page indices for the prompt tokens. 
kv_page_indices.extend(seq_metadata.block_table) - kv_page_indptr.append(kv_page_indptr[-1] + - len(seq_metadata.block_table)) - kv_last_page_len.append(context_len % self.block_size - or self.block_size) + kv_page_indptr.append(kv_page_indptr[-1] + len(seq_metadata.block_table)) + kv_last_page_len.append(context_len % self.block_size or self.block_size) # Convert to tensors. - self.qo_indptr = torch.tensor(qo_indptr, - dtype=torch.int32, - device=self.device) - self.kv_page_indices = torch.tensor(kv_page_indices, - dtype=torch.int32, - device=self.device) - self.kv_page_indptr = torch.tensor(kv_page_indptr, - dtype=torch.int32, - device=self.device) - self.kv_last_page_len = torch.tensor(kv_last_page_len, - dtype=torch.int32, - device=self.device) + self.qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32, device=self.device) + self.kv_page_indices = torch.tensor( + kv_page_indices, dtype=torch.int32, device=self.device + ) + self.kv_page_indptr = torch.tensor( + kv_page_indptr, dtype=torch.int32, device=self.device + ) + self.kv_last_page_len = torch.tensor( + kv_last_page_len, dtype=torch.int32, device=self.device + ) self._wrapper.begin_forward( self.qo_indptr, @@ -166,12 +159,9 @@ def forward( return torch.zeros_like(query) with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): - query = query.contiguous().reshape(-1, self.num_q_heads, - self.head_dim) - key = key.contiguous().reshape(-1, self.num_kv_heads, - self.head_dim) - value = value.contiguous().reshape(-1, self.num_kv_heads, - self.head_dim) + query = query.contiguous().reshape(-1, self.num_q_heads, self.head_dim) + key = key.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + value = value.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) with self.get_timer(OperationMetrics.ATTN_KV_CACHE_SAVE, layer_id): append_paged_kv_cache( diff --git a/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py b/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py index 00708b7..e0b9e77 100644 --- a/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py +++ b/sarathi/model_executor/attention/flashinfer_unpaged_attention_wrapper.py @@ -2,16 +2,16 @@ import torch from flashinfer import ( - single_prefill_with_kv_cache, - append_paged_kv_cache, BatchDecodeWithPagedKVCacheWrapper, + append_paged_kv_cache, + single_prefill_with_kv_cache, ) from sarathi.config import ModelConfig, ParallelConfig from sarathi.core.datatypes.sequence import SequenceMetadata +from sarathi.metrics.constants import OperationMetrics from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper from sarathi.model_executor.attention.kv_buffer import KVBuffer -from sarathi.metrics.constants import OperationMetrics class FlashinferUnpagedAttentionWrapper(BaseAttentionWrapper): @@ -26,11 +26,10 @@ def init( ): super().init(model_config, parallel_config, block_size, device) - workspace_buffer = torch.empty(16 * 1024 * 1024, - dtype=torch.uint8, - device=device) - self._wrapper = BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=device + ) + self._wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") self.kv_buffers: List[KVBuffer] = [] num_layers = model_config.get_num_layers(parallel_config) @@ -41,7 +40,7 @@ def init( self.num_kv_heads, self.head_dim, device, - self.dtype + self.dtype, ) ) @@ -93,7 +92,7 @@ def begin_forward( # in the ragged tensor. 
kv_page_indptr: List[int] = [0] decode_kv_page_indptr: List[int] = [0] - # we also create a qo_indptr tensor to capture the start of each sequence in the + # we also create a qo_indptr tensor to capture the start of each sequence in the # ragged tensor which is used for the kv cache append api. # qo_indptr: [0, prompt_0, prompt_0 + prompt_1, ..., prompt_0 + ... + prompt_N-1, generation_0, generation_0 + 1, ..., generation_0 + ... + M] qo_indptr: List[int] = [0] @@ -130,13 +129,14 @@ def begin_forward( # indptr for the prompt tokens in q/o tensor qo_indptr.append(qo_indptr[-1] + prompt_chunk_len) # Compute the kv page indices for the prompt tokens. - num_blocks_in_use = (current_total_len + self.block_size - - 1) // self.block_size - kv_page_indices.extend( - seq_metadata.block_table[:num_blocks_in_use]) + num_blocks_in_use = ( + current_total_len + self.block_size - 1 + ) // self.block_size + kv_page_indices.extend(seq_metadata.block_table[:num_blocks_in_use]) kv_page_indptr.append(kv_page_indptr[-1] + num_blocks_in_use) - kv_last_page_len.append(current_total_len % self.block_size - or self.block_size) + kv_last_page_len.append( + current_total_len % self.block_size or self.block_size + ) for seq_metadata in seq_metadata_list: if seq_metadata.block_table is None: @@ -154,37 +154,35 @@ def begin_forward( # Compute the kv page indices for the prompt tokens. kv_page_indices.extend(seq_metadata.block_table) decode_kv_page_indices.extend(seq_metadata.block_table) - kv_page_indptr.append(kv_page_indptr[-1] + - len(seq_metadata.block_table)) - decode_kv_page_indptr.append(decode_kv_page_indptr[-1] + - len(seq_metadata.block_table)) - kv_last_page_len.append(context_len % self.block_size - or self.block_size) - decode_kv_last_page_len.append(context_len % self.block_size - or self.block_size) + kv_page_indptr.append(kv_page_indptr[-1] + len(seq_metadata.block_table)) + decode_kv_page_indptr.append( + decode_kv_page_indptr[-1] + len(seq_metadata.block_table) + ) + kv_last_page_len.append(context_len % self.block_size or self.block_size) + decode_kv_last_page_len.append( + context_len % self.block_size or self.block_size + ) # Convert to tensors. 
- self.qo_indptr = torch.tensor(qo_indptr, - dtype=torch.int32, - device=self.device) - self.kv_page_indices = torch.tensor(kv_page_indices, - dtype=torch.int32, - device=self.device) - self.kv_page_indptr = torch.tensor(kv_page_indptr, - dtype=torch.int32, - device=self.device) - self.kv_last_page_len = torch.tensor(kv_last_page_len, - dtype=torch.int32, - device=self.device) - decode_kv_page_indices = torch.tensor(decode_kv_page_indices, - dtype=torch.int32, - device=self.device) - decode_kv_page_indptr = torch.tensor(decode_kv_page_indptr, - dtype=torch.int32, - device=self.device) - decode_kv_last_page_len = torch.tensor(decode_kv_last_page_len, - dtype=torch.int32, - device=self.device) + self.qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32, device=self.device) + self.kv_page_indices = torch.tensor( + kv_page_indices, dtype=torch.int32, device=self.device + ) + self.kv_page_indptr = torch.tensor( + kv_page_indptr, dtype=torch.int32, device=self.device + ) + self.kv_last_page_len = torch.tensor( + kv_last_page_len, dtype=torch.int32, device=self.device + ) + decode_kv_page_indices = torch.tensor( + decode_kv_page_indices, dtype=torch.int32, device=self.device + ) + decode_kv_page_indptr = torch.tensor( + decode_kv_page_indptr, dtype=torch.int32, device=self.device + ) + decode_kv_last_page_len = torch.tensor( + decode_kv_last_page_len, dtype=torch.int32, device=self.device + ) self.prompt_seq_ids = prompt_seq_ids self.prompt_chunk_lens = prompt_chunk_lens @@ -222,15 +220,12 @@ def forward( # there is no need to call attention in profiling mode return torch.zeros_like(query) - output = torch.empty_like(query).view(-1, self.num_q_heads,self.head_dim) + output = torch.empty_like(query).view(-1, self.num_q_heads, self.head_dim) with self.get_timer(OperationMetrics.ATTN_INPUT_RESHAPE, layer_id): - query = query.contiguous().reshape(-1, self.num_q_heads, - self.head_dim) - key = key.contiguous().reshape(-1, self.num_kv_heads, - self.head_dim) - value = value.contiguous().reshape(-1, self.num_kv_heads, - self.head_dim) + query = query.contiguous().reshape(-1, self.num_q_heads, self.head_dim) + key = key.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) + value = value.contiguous().reshape(-1, self.num_kv_heads, self.head_dim) qo_offset: int = 0 for i, seq_id in enumerate(self.prompt_seq_ids): @@ -240,15 +235,22 @@ def forward( processed_prompt_len = self.processed_prompt_lens[i] total_prompt_len = self.total_prompt_lens[i] - q = query[qo_offset:qo_offset+prompt_chunk_len] - k = key[qo_offset:qo_offset+prompt_chunk_len] - v = value[qo_offset:qo_offset+prompt_chunk_len] + q = query[qo_offset : qo_offset + prompt_chunk_len] + k = key[qo_offset : qo_offset + prompt_chunk_len] + v = value[qo_offset : qo_offset + prompt_chunk_len] if prompt_chunk_len == total_prompt_len: # if all the tokens are processed at once, we can skip the kv buffer management with self.get_timer(OperationMetrics.ATTN, layer_id): - output[qo_offset:qo_offset+prompt_chunk_len] = single_prefill_with_kv_cache( - q, k, v, causal=True, pos_encoding_mode="NONE", sm_scale=softmax_scale + output[qo_offset : qo_offset + prompt_chunk_len] = ( + single_prefill_with_kv_cache( + q, + k, + v, + causal=True, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) ) else: if seq_id not in kv_buffer.buffer_indices: @@ -257,11 +259,18 @@ def forward( kv_buffer.append(seq_id, k, v) k_, v_ = kv_buffer.get_kv_tensors(seq_id) with self.get_timer(OperationMetrics.ATTN, layer_id): - output[qo_offset:qo_offset+prompt_chunk_len] = 
single_prefill_with_kv_cache( - q, k_, v_, causal=True, pos_encoding_mode="NONE", sm_scale=softmax_scale + output[qo_offset : qo_offset + prompt_chunk_len] = ( + single_prefill_with_kv_cache( + q, + k_, + v_, + causal=True, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) ) - if total_prompt_len == processed_prompt_len + prompt_chunk_len: + if total_prompt_len == processed_prompt_len + prompt_chunk_len: kv_buffer.free_request(seq_id) qo_offset += prompt_chunk_len @@ -278,20 +287,20 @@ def forward( kv_layout="NHD", ) - if self.decode_batch_size > 0: with self.get_timer(OperationMetrics.ATTN, layer_id): - output[qo_offset:qo_offset+self.decode_batch_size] = self._wrapper.forward( - query[qo_offset:qo_offset+self.decode_batch_size], - kv_cache, - pos_encoding_mode="NONE", - sm_scale=softmax_scale, + output[qo_offset : qo_offset + self.decode_batch_size] = ( + self._wrapper.forward( + query[qo_offset : qo_offset + self.decode_batch_size], + kv_cache, + pos_encoding_mode="NONE", + sm_scale=softmax_scale, + ) ) qo_offset += self.decode_batch_size with self.get_timer(OperationMetrics.ATTN_OUTPUT_RESHAPE, layer_id): - output = output.reshape(-1, - self.num_q_heads * self.head_dim) + output = output.reshape(-1, self.num_q_heads * self.head_dim) self.layer_index += 1 assert self.layer_index <= len(self.kv_buffers) diff --git a/sarathi/model_executor/attention/kv_buffer.py b/sarathi/model_executor/attention/kv_buffer.py index f0d5ba6..2d7d7d4 100644 --- a/sarathi/model_executor/attention/kv_buffer.py +++ b/sarathi/model_executor/attention/kv_buffer.py @@ -1,6 +1,7 @@ -import torch from typing import Dict, Tuple +import torch + class KVBuffer: """ @@ -63,8 +64,7 @@ def get_kv_tensors(self, seq_id: int) -> Tuple[torch.Tensor, torch.Tensor]: self.v_buffer[start_offset:end_offset], ) - def append(self, seq_id: int, key: torch.Tensor, - value: torch.Tensor) -> None: + def append(self, seq_id: int, key: torch.Tensor, value: torch.Tensor) -> None: assert key.shape == value.shape active_length = self.buffer_active_lens[seq_id] assert active_length + key.shape[0] <= self.max_seq_len diff --git a/sarathi/model_executor/attention/no_op_attention_wrapper.py b/sarathi/model_executor/attention/no_op_attention_wrapper.py index d2f602e..8e9a55d 100644 --- a/sarathi/model_executor/attention/no_op_attention_wrapper.py +++ b/sarathi/model_executor/attention/no_op_attention_wrapper.py @@ -1,6 +1,7 @@ +from typing import List, Optional, Tuple + import torch -from typing import List, Optional, Tuple from sarathi.config import ModelConfig, ParallelConfig from sarathi.core.datatypes.sequence import SequenceMetadata from sarathi.model_executor.attention.base_attention_wrapper import BaseAttentionWrapper @@ -18,8 +19,9 @@ def init( ): self.device = device - def get_cache_block(self, num_blocks: int, - **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + def get_cache_block( + self, num_blocks: int, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: pass def begin_forward( diff --git a/sarathi/model_executor/layers/activation.py b/sarathi/model_executor/layers/activation.py index c15b723..90c5797 100644 --- a/sarathi/model_executor/layers/activation.py +++ b/sarathi/model_executor/layers/activation.py @@ -1,4 +1,5 @@ """Custom activation functions.""" + import torch import torch.nn as nn diff --git a/sarathi/model_executor/layers/layernorm.py b/sarathi/model_executor/layers/layernorm.py index 044677c..090e422 100644 --- a/sarathi/model_executor/layers/layernorm.py +++ b/sarathi/model_executor/layers/layernorm.py @@ -1,8 
+1,10 @@ """Custom normalization layers.""" + +from typing import Optional + import torch import torch.nn as nn -from typing import Optional from sarathi import layernorm_ops from sarathi.metrics.cuda_timer import CudaTimer diff --git a/sarathi/model_executor/layers/rotary_embedding.py b/sarathi/model_executor/layers/rotary_embedding.py index 3f67067..f6bacc8 100644 --- a/sarathi/model_executor/layers/rotary_embedding.py +++ b/sarathi/model_executor/layers/rotary_embedding.py @@ -63,17 +63,19 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / - self.rotary_dim)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda") + / self.rotary_dim + ) + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: """Compute the cos and sin cache.""" inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, - dtype=torch.float, - device="cuda") + t = torch.arange(self.max_position_embeddings, dtype=torch.float, device="cuda") freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() @@ -89,9 +91,14 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor]: # pos_encoding_ops.rotary_embedding() is an in-place operation that # updates the query and key tensors. - pos_encoding_ops.rotary_embedding(positions, query, key, - self.head_size, self.cos_sin_cache, - self.is_neox_style) + pos_encoding_ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key @@ -111,8 +118,9 @@ def __init__( scaling_factor: float, ) -> None: self.scaling_factor = scaling_factor - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style + ) def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) @@ -147,8 +155,9 @@ def __init__( scaling_factor: float, ) -> None: self.scaling_factor = scaling_factor - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style + ) def _compute_cos_sin_cache(self) -> torch.Tensor: # NOTE(woosuk): self.max_position_embeddings is the original @@ -157,9 +166,9 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: # self.max_position_embeddings * self.scaling_factor. 
max_len = self.max_position_embeddings * self.scaling_factor base = self.base * ( - (self.scaling_factor * max_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.rotary_dim / - (self.rotary_dim - 2)) + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) inv_freq = self._compute_inv_freq(base) t = torch.arange(max_len, dtype=torch.float, device="cuda") @@ -171,37 +180,41 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: # Inverse dim formula to find dim based on number of rotations -def _yarn_find_correction_dim(num_rotations: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> float: - return (dim * math.log(max_position_embeddings / - (num_rotations * 2 * math.pi))) / (2 * - math.log(base)) +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) # Find dim range bounds based on rotations -def _yarn_find_correction_range(low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> int: +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> int: low = math.floor( - _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) high = math.ceil( - _yarn_find_correction_dim(high_rot, dim, base, - max_position_embeddings)) + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) return max(low, 0), min(high, dim - 1) # Clamp values just in case -def _yarn_linear_ramp_mask(low: float, high: float, dim: int, - dtype: torch.dtype, - device: torch.device) -> torch.Tensor: +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: if low == high: high += 0.001 # Prevent singularity - linear_func = (torch.arange(dim, dtype=dtype, device=device) - - low) / (high - low) + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func @@ -238,37 +251,49 @@ def __init__( self.beta_fast = beta_fast self.beta_slow = beta_slow # Get n-d magnitude scaling corrected for interpolation - self.mscale = float( - _yarn_get_mscale(self.scaling_factor) * attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style + ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / - self.rotary_dim) + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda") + / self.rotary_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + 
self.max_position_embeddings, + ) # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - _yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, dtype=torch.float, - device="cuda")) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, device="cuda" + ) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device="cuda", - dtype=torch.float32) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32, + ) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) return cache @@ -282,35 +307,38 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]], ) -> RotaryEmbedding: if rope_scaling is None: - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style) + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style + ) else: scaling_type = rope_scaling["type"] scaling_factor = rope_scaling["factor"] if scaling_type == "linear": - rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor) + rotary_emb = LinearScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor + ) elif scaling_type == "dynamic": rotary_emb = DynamicNTKScalingRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor + ) elif scaling_type == "yarn": - original_max_position = rope_scaling[ - "original_max_position_embeddings"] + original_max_position = rope_scaling["original_max_position_embeddings"] assert max_position == original_max_position * scaling_factor extra_kwargs = { k: v for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow") + if k + in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") } - rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, - original_max_position, - base, is_neox_style, - scaling_factor, - **extra_kwargs) + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + **extra_kwargs, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") return rotary_emb diff --git a/sarathi/model_executor/layers/sampler.py b/sarathi/model_executor/layers/sampler.py index 594fab8..6392f74 100644 --- a/sarathi/model_executor/layers/sampler.py +++ b/sarathi/model_executor/layers/sampler.py @@ -1,13 +1,19 @@ """A layer that samples the next tokens from the model's outputs.""" + from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from sarathi.model_executor.parallel_utils.tensor_parallel import ( - gather_from_tensor_model_parallel_region) from sarathi.core.datatypes.sampling_params import SamplingType -from 
sarathi.core.datatypes.sequence import SamplerOutputs, SamplerOutput, SequenceMetadata +from sarathi.core.datatypes.sequence import ( + SamplerOutput, + SamplerOutputs, + SequenceMetadata, +) +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + gather_from_tensor_model_parallel_region, +) _SAMPLING_EPS = 1e-5 @@ -47,9 +53,7 @@ def forward( temperatures = _get_temperatures(seq_metadata_list) assert len(temperatures) == logits.shape[0] if any(t != 1.0 for t in temperatures): - t = torch.tensor(temperatures, - dtype=logits.dtype, - device=logits.device) + t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device) # Use in-place division to avoid creating a new tensor. logits.div_(t.unsqueeze(dim=1)) @@ -72,8 +76,9 @@ def forward( return _sample(probs, logprobs, seq_metadata_list) -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - vocab_size: int) -> torch.Tensor: +def _get_logits( + hidden_states: torch.Tensor, embedding: torch.Tensor, vocab_size: int +) -> torch.Tensor: # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) logits = gather_from_tensor_model_parallel_region(logits) @@ -97,14 +102,13 @@ def _prune_hidden_states( last_token_indices.append(token_idx) token_idx += 1 - last_token_indices = torch.tensor(last_token_indices, - dtype=torch.long, - device=hidden_states.device) + last_token_indices = torch.tensor( + last_token_indices, dtype=torch.long, device=hidden_states.device + ) return hidden_states.index_select(0, last_token_indices) -def _get_temperatures( - seq_metadata_list: List[SequenceMetadata]) -> List[float]: +def _get_temperatures(seq_metadata_list: List[SequenceMetadata]) -> List[float]: # Collect the temperatures for the logits. temperatures: List[float] = [] for seq_metadata in seq_metadata_list: @@ -158,20 +162,25 @@ def _apply_top_p_top_k( logits_sort[top_k_mask] = -float("inf") # Re-sort the probabilities. 
- logits = torch.gather(logits_sort, - dim=-1, - index=torch.argsort(logits_idx, dim=-1)) + logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1)) return logits def _greedy_sample( - logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + logprobs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: return torch.argmax(logprobs, dim=-1).view(-1).cpu().tolist() -def _random_sample(probs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: - random_samples = torch.multinomial( - probs, num_samples=1, replacement=True).view(-1).cpu().tolist() +def _random_sample( + probs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + random_samples = ( + torch.multinomial(probs, num_samples=1, replacement=True) + .view(-1) + .cpu() + .tolist() + ) return random_samples @@ -196,10 +205,10 @@ def _sample( num_tokens = category_num_tokens[sampling_type] if num_tokens == 0: continue - category_logprobs = logprobs[category_start_idx:category_start_idx + - num_tokens] - category_probs = probs[category_start_idx:category_start_idx + - num_tokens] + category_logprobs = logprobs[ + category_start_idx : category_start_idx + num_tokens + ] + category_probs = probs[category_start_idx : category_start_idx + num_tokens] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(category_logprobs) elif sampling_type == SamplingType.RANDOM: diff --git a/sarathi/model_executor/model_loader.py b/sarathi/model_executor/model_loader.py index 825daee..6baa299 100644 --- a/sarathi/model_executor/model_loader.py +++ b/sarathi/model_executor/model_loader.py @@ -1,4 +1,5 @@ """Utilities for selecting and loading models.""" + import contextlib from typing import Type @@ -38,7 +39,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: return _MODEL_REGISTRY[arch] raise ValueError( f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") + f"Supported architectures: {list(_MODEL_REGISTRY.keys())}" + ) def get_model(model_config: ModelConfig) -> nn.Module: @@ -55,6 +57,10 @@ def get_model(model_config: ModelConfig) -> nn.Module: initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. 
- model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/sarathi/model_executor/model_runner.py b/sarathi/model_executor/model_runner.py index 92427e0..37c2ade 100644 --- a/sarathi/model_executor/model_runner.py +++ b/sarathi/model_executor/model_runner.py @@ -1,22 +1,27 @@ -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import torch import torch.distributed -from sarathi.logger import init_logger -from sarathi.config import (CacheConfig, ModelConfig, ParallelConfig, - BaseSchedulerConfig, SchedulerType) -from sarathi.model_executor import get_model, set_random_seed +from sarathi.config import ( + BaseSchedulerConfig, + CacheConfig, + ModelConfig, + ParallelConfig, + SchedulerType, +) from sarathi.core.datatypes.sampling_params import SamplingParams -from sarathi.model_executor.layers.sampler import Sampler from sarathi.core.datatypes.sequence import Sequence, SequenceMetadata -from sarathi.worker.cache_engine import CacheEngine -from sarathi.utils import get_gpu_memory +from sarathi.logger import init_logger from sarathi.metrics.constants import CpuOperationMetrics, OperationMetrics from sarathi.metrics.cpu_timer import CpuTimer from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor import get_model, set_random_seed from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.sampler import Sampler from sarathi.model_executor.utils import pad_to_alignment +from sarathi.utils import get_gpu_memory +from sarathi.worker.cache_engine import CacheEngine logger = init_logger(__name__) @@ -48,15 +53,19 @@ def __init__( self.sampler: Optional[Sampler] = None if self.model.lm_head: - self.sampler = Sampler(self.model.lm_head.weight, - self.model.config.vocab_size) + self.sampler = Sampler( + self.model.lm_head.weight, self.model.config.vocab_size + ) self._prepare_inputs_e2e_timer = CpuTimer( - CpuOperationMetrics.PREPARE_INPUTS_E2E, rank=self.rank) - self._sampler_e2e_timer = CpuTimer(CpuOperationMetrics.SAMPLER_E2E, - rank=self.rank) + CpuOperationMetrics.PREPARE_INPUTS_E2E, rank=self.rank + ) + self._sampler_e2e_timer = CpuTimer( + CpuOperationMetrics.SAMPLER_E2E, rank=self.rank + ) self._model_execution_e2e_timer = CpuTimer( - CpuOperationMetrics.MODEL_EXECUTION_E2E, rank=self.rank) + CpuOperationMetrics.MODEL_EXECUTION_E2E, rank=self.rank + ) def _prepare_inputs( self, @@ -72,18 +81,17 @@ def _prepare_inputs( continue prompt_chunk_len = seq_metadata.prompt_chunk_len - current_prompt_chunk_tokens = seq_metadata.seq.get_next_prompt_chunk_token_ids( - prompt_chunk_len) + current_prompt_chunk_tokens = ( + seq_metadata.seq.get_next_prompt_chunk_token_ids(prompt_chunk_len) + ) current_prompt_chunk_len = len(current_prompt_chunk_tokens) current_prompt_chunk_lens.append(current_prompt_chunk_len) - processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed( - ) + processed_prompt_len = seq_metadata.seq.get_num_prompt_tokens_processed() current_total_len = processed_prompt_len + current_prompt_chunk_len input_tokens.extend(current_prompt_chunk_tokens) - input_positions.extend( - range(processed_prompt_len, current_total_len)) + input_positions.extend(range(processed_prompt_len, current_total_len)) for seq_metadata in seq_metadata_list: if seq_metadata.is_prompt: @@ 
-102,12 +110,10 @@ def _prepare_inputs( input_positions = pad_to_alignment(input_positions, multiple_of=8) # Convert to tensors. - tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device) + positions_tensor = torch.tensor( + input_positions, dtype=torch.long, device=self.device + ) return tokens_tensor, positions_tensor @@ -132,7 +138,10 @@ def profile_num_available_blocks( seq_metadata_list: List[SequenceMetadata] = [] - if self.scheduler_config.type == SchedulerType.SARATHI or self.scheduler_config.type == SchedulerType.SIMPLE_CHUNKING: + if ( + self.scheduler_config.type == SchedulerType.SARATHI + or self.scheduler_config.type == SchedulerType.SIMPLE_CHUNKING + ): # Profile memory usage with a single `chunk_size` chunk # which is the last chunk in the longest supported sequence. chunk_size = self.scheduler_config.chunk_size @@ -157,8 +166,9 @@ def profile_num_available_blocks( # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. for seq_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (seq_id < max_num_batched_tokens % max_num_seqs)) + seq_len = max_num_batched_tokens // max_num_seqs + ( + seq_id < max_num_batched_tokens % max_num_seqs + ) seq = Sequence( seq_id=seq_id, @@ -193,10 +203,12 @@ def profile_num_available_blocks( peak_memory = torch.cuda.max_memory_allocated() total_gpu_memory = get_gpu_memory() cache_block_size = CacheEngine.get_cache_block_size( - block_size, self.model_config, self.parallel_config) + block_size, self.model_config, self.parallel_config + ) num_gpu_blocks = int( - (total_gpu_memory * gpu_memory_utilization - peak_memory) // - cache_block_size) + (total_gpu_memory * gpu_memory_utilization - peak_memory) + // cache_block_size + ) num_gpu_blocks = max(num_gpu_blocks, 0) torch.cuda.empty_cache() @@ -214,8 +226,7 @@ def run( ) -> torch.Tensor: # Prepare input tensors. 
with self._prepare_inputs_e2e_timer: - input_tokens, input_positions = self._prepare_inputs( - seq_metadata_list) + input_tokens, input_positions = self._prepare_inputs(seq_metadata_list) get_attention_wrapper().begin_forward(seq_metadata_list) @@ -228,7 +239,9 @@ def run( kv_caches=gpu_cache, ) except RuntimeError as e: - logger.error(f"RuntimeError: {e} for seq_metadata_list: {seq_metadata_list}") + logger.error( + f"RuntimeError: {e} for seq_metadata_list: {seq_metadata_list}" + ) raise e with self._sampler_e2e_timer: diff --git a/sarathi/model_executor/models/__init__.py b/sarathi/model_executor/models/__init__.py index 9b65e75..5eecd6e 100644 --- a/sarathi/model_executor/models/__init__.py +++ b/sarathi/model_executor/models/__init__.py @@ -1,6 +1,6 @@ from sarathi.model_executor.models.falcon import FalconForCausalLM -from sarathi.model_executor.models.llama import LlamaForCausalLM from sarathi.model_executor.models.internlm import InternLMForCausalLM +from sarathi.model_executor.models.llama import LlamaForCausalLM from sarathi.model_executor.models.mistral import MistralForCausalLM from sarathi.model_executor.models.qwen import QWenLMHeadModel from sarathi.model_executor.models.yi import YiForCausalLM diff --git a/sarathi/model_executor/models/falcon.py b/sarathi/model_executor/models/falcon.py index 227c867..c0cb7ed 100644 --- a/sarathi/model_executor/models/falcon.py +++ b/sarathi/model_executor/models/falcon.py @@ -26,31 +26,31 @@ from torch.nn import LayerNorm from transformers import FalconConfig as HF_FalconConfig -from sarathi.model_executor.weight_utils import (convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_tensor_parallel_weights) -from sarathi.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from sarathi.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, - reduce_from_tensor_model_parallel_region) -from sarathi.transformers_utils.configs import RWConfig +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.rotary_embedding import get_rope from sarathi.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) -from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import ( - send, - recv, +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + reduce_from_tensor_model_parallel_region, ) -from sarathi.metrics.constants import OperationMetrics -from sarathi.metrics.cuda_timer import CudaTimer -from sarathi.model_executor.attention import get_attention_wrapper -from sarathi.model_executor.layers.rotary_embedding import get_rope +from sarathi.model_executor.weight_utils import ( + convert_pyslice_to_tensor, + hf_model_weights_iterator, + load_tensor_parallel_weights, +) +from sarathi.transformers_utils.configs import RWConfig from sarathi.worker.cache_engine import KVCache FalconConfig = Union[HF_FalconConfig, 
RWConfig] @@ -92,15 +92,13 @@ def __init__(self, config: FalconConfig): self.num_kv_heads = self.total_num_kv_heads // tp_size self.query_key_value = ColumnParallelLinear( self.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=config.bias, gather_output=False, perform_initialization=False, skip_bias_add=True, linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, - communication_metric_name=OperationMetrics. - ATTN_PRE_PROJ_ALL_GATHER, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, ) elif self.multi_query: self.total_num_kv_heads = 1 @@ -113,16 +111,15 @@ def __init__(self, config: FalconConfig): perform_initialization=False, skip_bias_add=True, ) - self.key_value = FalconLinear(self.hidden_size, - 2 * self.head_dim, - bias=config.bias) + self.key_value = FalconLinear( + self.hidden_size, 2 * self.head_dim, bias=config.bias + ) else: self.total_num_kv_heads = self.total_num_heads self.num_kv_heads = self.num_heads self.query_key_value = ColumnParallelLinear( self.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=config.bias, gather_output=False, perform_initialization=False, @@ -134,8 +131,9 @@ def __init__(self, config: FalconConfig): # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, @@ -145,19 +143,18 @@ def __init__(self, config: FalconConfig): skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, linear_metric_name=OperationMetrics.ATTN_POST_PROJ, - communication_metric_name=OperationMetrics. - ATTN_POST_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, ) self.use_rotary = config.rotary self.use_alibi = config.alibi - assert not (self.use_rotary and self.use_alibi), ( - "Rotary and alibi are mutually exclusive.") + assert not ( + self.use_rotary and self.use_alibi + ), "Rotary and alibi are mutually exclusive." 
if self.use_rotary: rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) rope_scaling = getattr(config, "rope_scaling", None) self.rotary_emb = get_rope( head_size=self.head_dim, @@ -171,8 +168,7 @@ def __init__(self, config: FalconConfig): elif self.use_alibi: raise NotImplementedError("ALiBi is not yet supported.") else: - raise NotImplementedError( - "Standard attention is not yet supported.") + raise NotImplementedError("Standard attention is not yet supported.") def forward( self, @@ -190,8 +186,7 @@ def forward( qkv, bias = self.query_key_value(hidden_states) if bias is not None: qkv += bias - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_rotary: with self._attn_rope_timer: q, k = self.rotary_emb(positions, q, k) @@ -224,8 +219,9 @@ def __init__(self, config: FalconConfig): communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, ) self.act = nn.GELU() - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, @@ -235,7 +231,7 @@ def __init__(self, config: FalconConfig): skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, - communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, ) self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) @@ -262,19 +258,19 @@ def __init__(self, config: FalconConfig): if config.new_decoder_architecture: # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) # The layer norm before the MLP self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: - self.input_layernorm = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if not config.parallel_attn: self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) + hidden_size, eps=config.layer_norm_epsilon + ) - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) def forward( self, @@ -348,17 +344,19 @@ def __init__(self, config: FalconConfig): ) # Transformer blocks - self.h = nn.ModuleList([ - FalconDecoderLayer(config) - for _ in range(config.num_hidden_layers // - get_pipeline_model_parallel_world_size()) - ]) + self.h = nn.ModuleList( + [ + FalconDecoderLayer(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) # Final Layer Norm self.ln_f = None if is_pipeline_last_stage(): - self.ln_f = LayerNorm(self.embed_dim, - eps=config.layer_norm_epsilon) + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( self, @@ -394,11 +392,13 @@ def __init__(self, config: FalconConfig): self.lm_head = None if self.is_pipeline_last_stage: - self.lm_head = 
ColumnParallelLinear(config.hidden_size, - config.vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) def forward( self, @@ -423,17 +423,21 @@ def forward( return hidden_states _column_parallel_weights = [ - "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight", - "dense_h_to_4h.bias" + "word_embeddings.weight", + "lm_head.weight", + "dense_h_to_4h.weight", + "dense_h_to_4h.bias", ] _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - tp_size = (get_tensor_model_parallel_world_size()) + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() pp_rank = get_pipeline_model_parallel_rank() @@ -472,13 +476,13 @@ def load_weights(self, state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if pp_rank != 0 and "word_embeddings" in name: continue - if pp_rank != pp_size - 1 and ("lm_head" in name - or "ln_f" in name): + if pp_rank != pp_size - 1 and ("lm_head" in name or "ln_f" in name): continue if "transformer.h" in name: @@ -493,42 +497,51 @@ def load_weights(self, loaded_weight = convert_pyslice_to_tensor(loaded_weight) loaded_weight_size = loaded_weight.size() loaded_weight = loaded_weight.view( - total_num_kv_heads, num_query_heads_per_kv_head + 2, - head_size, *loaded_weight_size[1:]) + total_num_kv_heads, + num_query_heads_per_kv_head + 2, + head_size, + *loaded_weight_size[1:], + ) wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) - wk = loaded_weight[:, [-2]].reshape(-1, - *loaded_weight_size[1:]) - wv = loaded_weight[:, [-1]].reshape(-1, - *loaded_weight_size[1:]) + wk = loaded_weight[:, [-2]].reshape(-1, *loaded_weight_size[1:]) + wv = loaded_weight[:, [-1]].reshape(-1, *loaded_weight_size[1:]) - wq = wq[head_size * head_start:head_size * head_end] - wk = wk[head_size * kv_head_start:head_size * kv_head_end] - wv = wv[head_size * kv_head_start:head_size * kv_head_end] + wq = wq[head_size * head_start : head_size * head_end] + wk = wk[head_size * kv_head_start : head_size * kv_head_end] + wv = wv[head_size * kv_head_start : head_size * kv_head_end] if separated_q_kv: loaded_weight_q = wq loaded_weight_kv = torch.cat([wk, wv], dim=0) q_weight_name = name.replace("query_key_value", "query") - kv_weight_name = name.replace("query_key_value", - "key_value") - load_tensor_parallel_weights(state_dict[q_weight_name], - loaded_weight_q, - q_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank) - load_tensor_parallel_weights(state_dict[kv_weight_name], - loaded_weight_kv, - kv_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank) + kv_weight_name = name.replace("query_key_value", "key_value") + load_tensor_parallel_weights( + state_dict[q_weight_name], + loaded_weight_q, + q_weight_name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) + 
load_tensor_parallel_weights( + state_dict[kv_weight_name], + loaded_weight_kv, + kv_weight_name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) continue else: loaded_weight = torch.cat([wq, wk, wv], dim=0) param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) diff --git a/sarathi/model_executor/models/internlm.py b/sarathi/model_executor/models/internlm.py index 3ab6872..07de881 100644 --- a/sarathi/model_executor/models/internlm.py +++ b/sarathi/model_executor/models/internlm.py @@ -7,23 +7,25 @@ from sarathi.metrics.constants import OperationMetrics from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper from sarathi.model_executor.layers.activation import SiluAndMul from sarathi.model_executor.layers.layernorm import RMSNorm -from sarathi.model_executor.weight_utils import ( - hf_model_weights_iterator, load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights) +from sarathi.model_executor.layers.rotary_embedding import get_rope from sarathi.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import ( - send, - recv, -) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send from sarathi.model_executor.parallel_utils.tensor_parallel import ( - ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding) -from sarathi.model_executor.layers.rotary_embedding import get_rope -from sarathi.model_executor.attention import get_attention_wrapper + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) from sarathi.worker.cache_engine import KVCache @@ -52,12 +54,13 @@ def __init__( input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, - communication_metric_name=OperationMetrics. - MLP_DOWN_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) @@ -82,12 +85,10 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads self.scaling = self.head_dim**-0.5 @@ -98,8 +99,7 @@ def __init__( gather_output=False, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, - communication_metric_name=OperationMetrics. 
- ATTN_PRE_PROJ_ALL_GATHER, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -108,8 +108,7 @@ def __init__( input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_POST_PROJ, - communication_metric_name=OperationMetrics. - ATTN_POST_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, ) self.rotary_emb = get_rope( @@ -152,8 +151,7 @@ def __init__(self, config: LlamaConfig): super().__init__() rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = InternLMAttention( hidden_size=config.hidden_size, num_heads=config.num_attention_heads, @@ -167,10 +165,10 @@ def __init__(self, config: LlamaConfig): intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -206,11 +204,11 @@ def __init__(self, config: LlamaConfig): vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, config.hidden_size, perform_initialization=False) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config) - for _ in range(config.num_hidden_layers) - ]) + vocab_size, config.hidden_size, perform_initialization=False + ) + self.layers = nn.ModuleList( + [InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -238,11 +236,13 @@ def __init__(self, config): self.config = config self.model = InternLMModel(config) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) def forward( self, @@ -253,42 +253,46 @@ def forward( hidden_states = self.model(hidden_states, positions, kv_caches) return hidden_states - _column_parallel_weights = [ - "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" - ] + _column_parallel_weights = ["qkv_proj.weight", "gate_proj.weight", "up_proj.weight"] _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): tensor_model_parallel_rank = get_tensor_model_parallel_rank() state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue if "embed_tokens" in name or "lm_head" in name: param = state_dict[name] - 
load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) + load_padded_tensor_parallel_vocab( + param, loaded_weight, tensor_model_parallel_rank + ) continue is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): + for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): if att_weight_name not in name: continue param = state_dict[name.replace(att_weight_name, "qkv_proj")] shard_size = param.shape[0] // 3 loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_attention_weight = True @@ -303,10 +307,13 @@ def load_weights(self, param = state_dict[name.replace(weight_name, "gate_up_proj")] shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -315,7 +322,11 @@ def load_weights(self, continue param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank, + ) diff --git a/sarathi/model_executor/models/llama.py b/sarathi/model_executor/models/llama.py index 2761a40..eb180c4 100644 --- a/sarathi/model_executor/models/llama.py +++ b/sarathi/model_executor/models/llama.py @@ -32,27 +32,29 @@ from sarathi.metrics.constants import OperationMetrics from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper from sarathi.model_executor.layers.activation import SiluAndMul from sarathi.model_executor.layers.layernorm import RMSNorm -from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.rotary_embedding import get_rope from sarathi.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) -from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import ( - send, - recv, -) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send from sarathi.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) from sarathi.model_executor.weight_utils import ( - hf_model_weights_iterator, load_tensor_parallel_weights, - 
load_padded_tensor_parallel_vocab) -from sarathi.model_executor.layers.rotary_embedding import get_rope + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) from sarathi.worker.cache_engine import KVCache @@ -74,7 +76,8 @@ def __init__( perform_initialization=False, linear_metric_name=OperationMetrics.MLP_UP_PROJ, communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, - layer_id=layer_id) + layer_id=layer_id, + ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -82,16 +85,19 @@ def __init__( input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, - communication_metric_name=OperationMetrics. - MLP_DOWN_PROJ_ALL_REDUCE, - layer_id=layer_id) + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, + layer_id=layer_id, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() - self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION, - layer_id=layer_id) + self._mlp_activation_timer = CudaTimer( + OperationMetrics.MLP_ACTIVATION, layer_id=layer_id + ) def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -103,14 +109,16 @@ def forward(self, x): class LlamaAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - layer_id: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + layer_id: Optional[int] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -130,14 +138,12 @@ def __init__(self, self.qkv_proj = ColumnParallelLinear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, gather_output=False, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, - communication_metric_name=OperationMetrics. - ATTN_PRE_PROJ_ALL_GATHER, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, layer_id=layer_id, ) self.o_proj = RowParallelLinear( @@ -147,8 +153,7 @@ def __init__(self, input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_POST_PROJ, - communication_metric_name=OperationMetrics. 
- ATTN_POST_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, layer_id=layer_id, ) self.rotary_emb = get_rope( @@ -195,8 +200,7 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -216,12 +220,14 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps, norm_name=OperationMetrics.INPUT_LAYERNORM, - layer_id=layer_id) + layer_id=layer_id, + ) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, norm_name=OperationMetrics.POST_ATTENTION_LAYERNORM, - layer_id=layer_id) + layer_id=layer_id, + ) def forward( self, @@ -269,13 +275,16 @@ def __init__( communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, ) - num_layers = config.num_hidden_layers // get_pipeline_model_parallel_world_size( + num_layers = ( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() ) layer_offset = get_pipeline_model_parallel_rank() * num_layers - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, layer_id=layer_id + layer_offset) - for layer_id in range(num_layers) - ]) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, layer_id=layer_id + layer_offset) + for layer_id in range(num_layers) + ] + ) self.norm = None if is_pipeline_last_stage(): @@ -320,11 +329,13 @@ def __init__( self.lm_head = None if self.is_pipeline_last_stage: - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) def forward( self, @@ -351,11 +362,13 @@ def forward( _column_parallel_layers = [] _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): weight_suffixes = ["weight"] column_parallel_weights: List[str] = [] @@ -378,30 +391,33 @@ def load_weights(self, first_layer_id = layers_per_stage * pp_model_parallel_rank last_layer_id = layers_per_stage * (pp_model_parallel_rank + 1) - 1 - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) + q_proj_shard_size = self.config.hidden_size // tp_size + kv_proj_shard_size = ( + self.config.hidden_size + // self.config.num_attention_heads + * self.config.num_key_value_heads + // tp_size + ) attention_weight_specs = [ # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size), ] state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + 
): if "rotary_emb.inv_freq" in name: continue - if pp_model_parallel_rank != 0 \ - and "embed_tokens" in name: + if pp_model_parallel_rank != 0 and "embed_tokens" in name: continue - if pp_model_parallel_rank != pp_size - 1 \ - and ("lm_head" in name or name == "model.norm.weight"): + if pp_model_parallel_rank != pp_size - 1 and ( + "lm_head" in name or name == "model.norm.weight" + ): continue if "model.layers" in name: @@ -419,9 +435,11 @@ def load_weights(self, param = state_dict[name.replace(weight_name, "qkv_proj")] loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[offset : offset + shard_size] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) @@ -438,10 +456,13 @@ def load_weights(self, shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -452,11 +473,16 @@ def load_weights(self, param = state_dict[name] if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) + load_padded_tensor_parallel_vocab( + param, loaded_weight, tensor_model_parallel_rank + ) continue - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, - tensor_model_parallel_rank) + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weights, + row_parallel_weights, + tensor_model_parallel_rank, + ) diff --git a/sarathi/model_executor/models/mistral.py b/sarathi/model_executor/models/mistral.py index f10ff03..c6b1463 100644 --- a/sarathi/model_executor/models/mistral.py +++ b/sarathi/model_executor/models/mistral.py @@ -24,35 +24,37 @@ The input of the model is flattened to a 1D tensor of tokens. 
""" -from typing import List, Optional, Any, Dict +from typing import Any, Dict, List, Optional import torch from torch import nn from transformers import MistralConfig +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper from sarathi.model_executor.layers.activation import SiluAndMul from sarathi.model_executor.layers.layernorm import RMSNorm -from sarathi.model_executor.attention import get_attention_wrapper from sarathi.model_executor.layers.rotary_embedding import get_rope from sarathi.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) -from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import ( - send, - recv, -) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send from sarathi.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) from sarathi.model_executor.weight_utils import ( - hf_model_weights_iterator, load_tensor_parallel_weights, - load_padded_tensor_parallel_vocab) -from sarathi.metrics.constants import OperationMetrics -from sarathi.metrics.cuda_timer import CudaTimer + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) from sarathi.worker.cache_engine import KVCache @@ -72,7 +74,8 @@ def __init__( gather_output=False, perform_initialization=False, linear_metric_name=OperationMetrics.MLP_UP_PROJ, - communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER) + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -80,11 +83,13 @@ def __init__( input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, - communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) @@ -99,13 +104,15 @@ def forward(self, x): class MistralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -123,14 +130,12 @@ def __init__(self, self.qkv_proj = ColumnParallelLinear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, + (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, gather_output=False, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, - communication_metric_name=OperationMetrics. - ATTN_PRE_PROJ_ALL_GATHER, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -139,8 +144,7 @@ def __init__(self, input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_POST_PROJ, - communication_metric_name=OperationMetrics. - ATTN_POST_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, ) self.rotary_emb = get_rope( head_size=self.head_dim, @@ -190,16 +194,17 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - rope_scaling=rope_scaling) + rope_scaling=rope_scaling, + ) self.mlp = MistralMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -247,11 +252,14 @@ def __init__( communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, ) - self.layers = nn.ModuleList([ - MistralDecoderLayer(config) - for _ in range(config.num_hidden_layers // - get_pipeline_model_parallel_world_size()) - ]) + self.layers = nn.ModuleList( + [ + MistralDecoderLayer(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) self.norm = None if is_pipeline_last_stage(): @@ -294,11 +302,13 @@ def __init__( self.lm_head = None if self.is_pipeline_last_stage: - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) def forward( self, @@ -325,11 +335,13 @@ def forward( _column_parallel_layers = [] _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: 
Optional[str] = None, + ): weight_suffixes = ["weight"] column_parallel_weights: List[str] = [] @@ -352,30 +364,33 @@ def load_weights(self, first_layer_id = layers_per_stage * pp_model_parallel_rank last_layer_id = layers_per_stage * (pp_model_parallel_rank + 1) - 1 - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) + q_proj_shard_size = self.config.hidden_size // tp_size + kv_proj_shard_size = ( + self.config.hidden_size + // self.config.num_attention_heads + * self.config.num_key_value_heads + // tp_size + ) attention_weight_specs = [ # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size), ] state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue - if pp_model_parallel_rank != 0 \ - and "embed_tokens" in name: + if pp_model_parallel_rank != 0 and "embed_tokens" in name: continue - if pp_model_parallel_rank != pp_size - 1 \ - and ("lm_head" in name or name == "model.norm.weight"): + if pp_model_parallel_rank != pp_size - 1 and ( + "lm_head" in name or name == "model.norm.weight" + ): continue if "model.layers" in name: @@ -393,9 +408,11 @@ def load_weights(self, param = state_dict[name.replace(weight_name, "qkv_proj")] loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[offset : offset + shard_size] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) @@ -412,10 +429,13 @@ def load_weights(self, shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + shard_size + * tensor_model_parallel_rank : shard_size + * (tensor_model_parallel_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -426,11 +446,16 @@ def load_weights(self, param = state_dict[name] if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) + load_padded_tensor_parallel_vocab( + param, loaded_weight, tensor_model_parallel_rank + ) continue - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, - tensor_model_parallel_rank) + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weights, + row_parallel_weights, + tensor_model_parallel_rank, + ) diff --git a/sarathi/model_executor/models/qwen.py b/sarathi/model_executor/models/qwen.py index c9f5134..48f88dd 100644 --- a/sarathi/model_executor/models/qwen.py +++ b/sarathi/model_executor/models/qwen.py @@ -14,34 +14,31 @@ from sarathi.metrics.constants import OperationMetrics 
from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import get_attention_wrapper from sarathi.model_executor.layers.activation import SiluAndMul from sarathi.model_executor.layers.layernorm import RMSNorm -from sarathi.model_executor.weight_utils import ( - convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights, -) +from sarathi.model_executor.layers.rotary_embedding import get_rope from sarathi.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) -from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import ( - send, - recv, -) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send from sarathi.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, + VocabParallelEmbedding, +) +from sarathi.model_executor.weight_utils import ( + convert_pyslice_to_tensor, + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, ) from sarathi.transformers_utils.configs.qwen import QWenConfig -from sarathi.model_executor.layers.rotary_embedding import get_rope -from sarathi.model_executor.attention import get_attention_wrapper from sarathi.worker.cache_engine import KVCache @@ -70,12 +67,13 @@ def __init__( input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, - communication_metric_name=OperationMetrics. - MLP_DOWN_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) @@ -99,12 +97,10 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads # pylint: disable=invalid-name @@ -115,8 +111,7 @@ def __init__( gather_output=False, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, - communication_metric_name=OperationMetrics. - ATTN_PRE_PROJ_ALL_GATHER, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -125,8 +120,7 @@ def __init__( input_is_parallel=True, perform_initialization=False, linear_metric_name=OperationMetrics.ATTN_POST_PROJ, - communication_metric_name=OperationMetrics. 
- ATTN_POST_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, ) self.scaling = self.head_dim**-0.5 @@ -222,18 +216,20 @@ def __init__(self, config: QWenConfig): if is_pipeline_first_stage(): vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.wte = VocabParallelEmbedding(vocab_size, - config.hidden_size, - perform_initialization=False) - self.h = nn.ModuleList([ - QWenBlock(config) - for _ in range(config.num_hidden_layers // - get_pipeline_model_parallel_world_size()) - ]) + self.wte = VocabParallelEmbedding( + vocab_size, config.hidden_size, perform_initialization=False + ) + self.h = nn.ModuleList( + [ + QWenBlock(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) self.ln_f = None if is_pipeline_last_stage(): - self.ln_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -323,15 +319,15 @@ def load_weights( last_layer_id = layers_per_stage * (pp_rank + 1) - 1 for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue if pp_rank != 0 and "wte" in name: continue - if pp_rank != pp_world_size - 1 \ - and ("lm_head" in name or "ln_f" in name): + if pp_rank != pp_world_size - 1 and ("lm_head" in name or "ln_f" in name): continue loaded_weight = convert_pyslice_to_tensor(loaded_weight) @@ -353,13 +349,13 @@ def load_weights( head_end = (tp_rank + 1) * num_heads if "weight" in name: - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) + loaded_weight = loaded_weight.view( + 3, total_num_heads, head_size, hidden_size + ) loaded_weight = loaded_weight[:, head_start:head_end, :, :] loaded_weight = loaded_weight.reshape(-1, hidden_size) elif "bias" in name: - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) loaded_weight = loaded_weight[:, head_start:head_end, :] loaded_weight = loaded_weight.reshape(-1) @@ -369,10 +365,12 @@ def load_weights( continue param = state_dict[name.replace(weight_name, "gate_up_proj")] shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + loaded_weight = loaded_weight[ + shard_size * tp_rank : shard_size * (tp_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -383,8 +381,7 @@ def load_weights( param = state_dict[name] if "wte" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) + load_padded_tensor_parallel_vocab(param, loaded_weight, tp_rank) continue load_tensor_parallel_weights( diff --git a/sarathi/model_executor/models/yi.py b/sarathi/model_executor/models/yi.py index c3fb118..137a3c1 100644 --- a/sarathi/model_executor/models/yi.py +++ b/sarathi/model_executor/models/yi.py @@ -28,31 +28,33 @@ import torch from torch import nn -from sarathi.transformers_utils.configs.yi import YiConfig +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer +from sarathi.model_executor.attention import 
get_attention_wrapper from sarathi.model_executor.layers.activation import SiluAndMul from sarathi.model_executor.layers.layernorm import RMSNorm -from sarathi.model_executor.attention import get_attention_wrapper +from sarathi.model_executor.layers.rotary_embedding import get_rope from sarathi.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) -from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import ( - send, - recv, -) +from sarathi.model_executor.parallel_utils.pipeline_parallel.mappings import recv, send from sarathi.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) from sarathi.model_executor.weight_utils import ( - hf_model_weights_iterator, load_tensor_parallel_weights, - load_padded_tensor_parallel_vocab) -from sarathi.metrics.constants import OperationMetrics -from sarathi.metrics.cuda_timer import CudaTimer -from sarathi.model_executor.layers.rotary_embedding import get_rope + hf_model_weights_iterator, + load_padded_tensor_parallel_vocab, + load_tensor_parallel_weights, +) +from sarathi.transformers_utils.configs.yi import YiConfig from sarathi.worker.cache_engine import KVCache @@ -71,18 +73,21 @@ def __init__( bias=False, gather_output=False, linear_metric_name=OperationMetrics.MLP_UP_PROJ, - communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER) + communication_metric_name=OperationMetrics.MLP_UP_PROJ_ALL_GATHER, + ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, input_is_parallel=True, linear_metric_name=OperationMetrics.MLP_DOWN_PROJ, - communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE + communication_metric_name=OperationMetrics.MLP_DOWN_PROJ_ALL_REDUCE, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() self._mlp_activation_timer = CudaTimer(OperationMetrics.MLP_ACTIVATION) @@ -131,14 +136,12 @@ def __init__( self.qkv_proj = ColumnParallelLinear( hidden_size, - (self.total_num_heads + - 2 * self.total_num_kv_heads * num_kv_heads_replicas) * - self.head_dim, + (self.total_num_heads + 2 * self.total_num_kv_heads * num_kv_heads_replicas) + * self.head_dim, bias=False, gather_output=False, linear_metric_name=OperationMetrics.ATTN_PRE_PROJ, - communication_metric_name=OperationMetrics. - ATTN_PRE_PROJ_ALL_GATHER, + communication_metric_name=OperationMetrics.ATTN_PRE_PROJ_ALL_GATHER, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -146,8 +149,7 @@ def __init__( bias=False, input_is_parallel=True, linear_metric_name=OperationMetrics.ATTN_POST_PROJ, - communication_metric_name=OperationMetrics. 
- ATTN_POST_PROJ_ALL_REDUCE, + communication_metric_name=OperationMetrics.ATTN_POST_PROJ_ALL_REDUCE, ) self.rotary_emb = get_rope( head_size=self.num_heads, @@ -191,8 +193,7 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = YiAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -253,11 +254,14 @@ def __init__( linear_metric_name=OperationMetrics.EMBED_LINEAR, communication_metric_name=OperationMetrics.EMBED_ALL_REDUCE, ) - self.layers = nn.ModuleList([ - YiDecoderLayer(config) - for _ in range(config.num_hidden_layers // - get_pipeline_model_parallel_world_size()) - ]) + self.layers = nn.ModuleList( + [ + YiDecoderLayer(config) + for _ in range( + config.num_hidden_layers // get_pipeline_model_parallel_world_size() + ) + ] + ) self.norm = None if is_pipeline_last_stage(): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -299,10 +303,9 @@ def __init__( self.lm_head = None if self.is_pipeline_last_stage: - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, vocab_size, bias=False, gather_output=False + ) def forward( self, @@ -329,11 +332,13 @@ def forward( _column_parallel_layers = [] _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): weight_suffixes = ["weight"] weight_suffixes = ["weight"] @@ -358,34 +363,34 @@ def load_weights(self, first_layer_id = layers_per_stage * pp_rank last_layer_id = layers_per_stage * (pp_rank + 1) - 1 - q_proj_shard_size = (self.config.hidden_size // tp_size) - num_kv_heads_replicas = max(1, - tp_size // self.config.num_key_value_heads) - num_kv_heads_per_gpu = max(1, - self.config.num_key_value_heads // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - num_kv_heads_per_gpu) + q_proj_shard_size = self.config.hidden_size // tp_size + num_kv_heads_replicas = max(1, tp_size // self.config.num_key_value_heads) + num_kv_heads_per_gpu = max(1, self.config.num_key_value_heads // tp_size) + kv_proj_shard_size = ( + self.config.hidden_size + // self.config.num_attention_heads + * num_kv_heads_per_gpu + ) attention_weight_specs = [ # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size), ] state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue - if pp_rank != 0 \ - and "embed_tokens" in name: + if pp_rank != 0 and "embed_tokens" in name: continue - if pp_rank != pp_size - 1 \ - and ("lm_head" in name or name == "model.norm.weight"): + if pp_rank != pp_size - 1 and ( + "lm_head" in name or name == 
"model.norm.weight" + ): continue if "model.layers" in name: @@ -405,10 +410,10 @@ def load_weights(self, shard_id = tp_rank // num_kv_heads_replicas else: shard_id = tp_rank - loaded_weight = loaded_weight[shard_size * - shard_id:shard_size * - (shard_id + 1)] - param_slice = param.data[offset:offset + shard_size] + loaded_weight = loaded_weight[ + shard_size * shard_id : shard_size * (shard_id + 1) + ] + param_slice = param.data[offset : offset + shard_size] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) @@ -424,10 +429,12 @@ def load_weights(self, param = state_dict[name.replace(weight_name, "gate_up_proj")] shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + loaded_weight = loaded_weight[ + shard_size * tp_rank : shard_size * (tp_rank + 1) + ] + param_slice = param.data[ + shard_size * stride_id : shard_size * (stride_id + 1) + ] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True @@ -438,10 +445,14 @@ def load_weights(self, param = state_dict[name] if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) + load_padded_tensor_parallel_vocab(param, loaded_weight, tp_rank) continue - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, tp_rank) + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weights, + row_parallel_weights, + tp_rank, + ) diff --git a/sarathi/model_executor/parallel_utils/parallel_state.py b/sarathi/model_executor/parallel_utils/parallel_state.py index a8a815a..818d20c 100644 --- a/sarathi/model_executor/parallel_utils/parallel_state.py +++ b/sarathi/model_executor/parallel_utils/parallel_state.py @@ -4,9 +4,10 @@ """Model and data parallel groups.""" -import torch from typing import Optional +import torch + # Intra-layer model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Inter-layer model parallel group that the current rank belongs to. 
@@ -88,21 +89,26 @@ def initialize_model_parallel( f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" ) - data_parallel_size: int = world_size // (tensor_model_parallel_size * - pipeline_model_parallel_size) + data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size + ) - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size num_data_parallel_groups: int = world_size // data_parallel_size if virtual_pipeline_model_parallel_size is not None: if not pipeline_model_parallel_size > 2: - raise RuntimeError("pipeline-model-parallel size should be greater than 2 with " - "interleaved schedule") + raise RuntimeError( + "pipeline-model-parallel size should be greater than 2 with " + "interleaved schedule" + ) global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = ( + virtual_pipeline_model_parallel_size + ) if pipeline_model_parallel_split_rank is not None: global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK @@ -113,7 +119,7 @@ def initialize_model_parallel( # Build the data-parallel groups. global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_GLOBAL_RANKS - assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" all_data_parallel_group_ranks = [] for i in range(pipeline_model_parallel_size): start_rank = i * num_pipeline_model_parallel_groups @@ -128,21 +134,25 @@ def initialize_model_parallel( # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" for i in range(data_parallel_size): - ranks = [data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks] + ranks = [ + data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks + ] group = torch.distributed.new_group(ranks) if rank in ranks: _MODEL_PARALLEL_GROUP = group # Build the tensor model-parallel groups. global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ - 'tensor model parallel group is already initialized' + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is None + ), "tensor model parallel group is already initialized" for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = range( + i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size + ) group = torch.distributed.new_group(ranks) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group @@ -151,15 +161,17 @@ def initialize_model_parallel( # (first and last rank in each pipeline model-parallel group). 
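For reference, the group bookkeeping this function performs reduces to simple integer arithmetic; a sketch with a hypothetical 8-rank world (no torch.distributed calls):

    # Hypothetical: 8 GPUs, tensor_parallel_size=2, pipeline_parallel_size=2 -> data_parallel_size=2.
    world_size, tp, pp = 8, 2, 2
    dp = world_size // (tp * pp)

    # Same rank layout the loops above pass to torch.distributed.new_group.
    tensor_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)]
    pipeline_groups = [list(range(i, world_size, world_size // pp)) for i in range(world_size // pp)]

    print(dp)               # 2
    print(tensor_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(pipeline_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]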
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ - 'pipeline model parallel group is already initialized' + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is None + ), "pipeline model parallel group is already initialized" global _EMBEDDING_GROUP global _EMBEDDING_GLOBAL_RANKS - assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + assert _EMBEDDING_GROUP is None, "embedding group is already initialized" global _POSITION_EMBEDDING_GROUP global _POSITION_EMBEDDING_GLOBAL_RANKS - assert _POSITION_EMBEDDING_GROUP is None, \ - 'position embedding group is already initialized' + assert ( + _POSITION_EMBEDDING_GROUP is None + ), "position embedding group is already initialized" for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) group = torch.distributed.new_group(ranks) @@ -173,12 +185,19 @@ def initialize_model_parallel( position_embedding_ranks = [ranks[0]] if pipeline_model_parallel_split_rank is not None: if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: - embedding_ranks = [ranks[0], - ranks[pipeline_model_parallel_split_rank], - ranks[-1]] - if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: - position_embedding_ranks = [ranks[0], - ranks[pipeline_model_parallel_split_rank]] + embedding_ranks = [ + ranks[0], + ranks[pipeline_model_parallel_split_rank], + ranks[-1], + ] + if ( + ranks[pipeline_model_parallel_split_rank] + not in position_embedding_ranks + ): + position_embedding_ranks = [ + ranks[0], + ranks[pipeline_model_parallel_split_rank], + ] else: embedding_ranks = ranks position_embedding_ranks = ranks @@ -195,54 +214,57 @@ def initialize_model_parallel( if rank in ranks: _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + def model_parallel_is_initialized(): """Check if model and data parallel groups are initialized.""" - if _TENSOR_MODEL_PARALLEL_GROUP is None or \ - _PIPELINE_MODEL_PARALLEL_GROUP is None or \ - _DATA_PARALLEL_GROUP is None: + if ( + _TENSOR_MODEL_PARALLEL_GROUP is None + or _PIPELINE_MODEL_PARALLEL_GROUP is None + or _DATA_PARALLEL_GROUP is None + ): return False return True def get_model_parallel_group(): """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, \ - 'model parallel group is not initialized' + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" return _MODEL_PARALLEL_GROUP def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ - 'intra_layer_model parallel group is not initialized' + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is not None + ), "intra_layer_model parallel group is not initialized" return _TENSOR_MODEL_PARALLEL_GROUP def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \ - 'pipeline_model parallel group is not initialized' + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is not None + ), "pipeline_model parallel group is not initialized" return _PIPELINE_MODEL_PARALLEL_GROUP def get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP is not None, \ - 'data parallel group is not initialized' + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" 
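The remaining hunks in this file apply the same mechanical change: once a line exceeds its length limit, black replaces the backslash-continued assert with a parenthesized condition. An illustrative before/after (not taken from the repository):

    _SOME_GROUP = object()  # hypothetical module-level group handle

    # yapf-era style: backslash continuation
    assert _SOME_GROUP is not None, \
        "some group is not initialized"

    # black style: the condition is wrapped in parentheses instead,
    # with the message kept after the closing parenthesis.
    assert (
        _SOME_GROUP is not None
    ), "some group is not initialized"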
return _DATA_PARALLEL_GROUP def get_embedding_group(): """Get the embedding group the caller rank belongs to.""" - assert _EMBEDDING_GROUP is not None, \ - 'embedding group is not initialized' + assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" return _EMBEDDING_GROUP def get_position_embedding_group(): """Get the position embedding group the caller rank belongs to.""" - assert _POSITION_EMBEDDING_GROUP is not None, \ - 'position embedding group is not initialized' + assert ( + _POSITION_EMBEDDING_GROUP is not None + ), "position embedding group is not initialized" return _POSITION_EMBEDDING_GROUP @@ -308,12 +330,13 @@ def get_pipeline_model_parallel_rank(): return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) - def is_pipeline_first_stage(ignore_virtual=False): """Return True if in the first pipeline model-parallel stage, False otherwise.""" if not ignore_virtual: - if get_virtual_pipeline_model_parallel_world_size() is not None and \ - get_virtual_pipeline_model_parallel_rank() != 0: + if ( + get_virtual_pipeline_model_parallel_world_size() is not None + and get_virtual_pipeline_model_parallel_rank() != 0 + ): return False return get_pipeline_model_parallel_rank() == 0 @@ -321,14 +344,18 @@ def is_pipeline_first_stage(ignore_virtual=False): def is_pipeline_last_stage(ignore_virtual=False): """Return True if in the last pipeline model-parallel stage, False otherwise.""" if not ignore_virtual: - virtual_pipeline_model_parallel_world_size = \ + virtual_pipeline_model_parallel_world_size = ( get_virtual_pipeline_model_parallel_world_size() - if virtual_pipeline_model_parallel_world_size is not None and \ - get_virtual_pipeline_model_parallel_rank() != ( - virtual_pipeline_model_parallel_world_size - 1): + ) + if ( + virtual_pipeline_model_parallel_world_size is not None + and get_virtual_pipeline_model_parallel_rank() + != (virtual_pipeline_model_parallel_world_size - 1) + ): return False return get_pipeline_model_parallel_rank() == ( - get_pipeline_model_parallel_world_size() - 1) + get_pipeline_model_parallel_world_size() - 1 + ) def is_rank_in_embedding_group(ignore_virtual=False): @@ -389,8 +416,9 @@ def is_pipeline_stage_at_split(): stage executes encoder block for a model with both encoder and decoder.""" rank = get_pipeline_model_parallel_rank() - return is_pipeline_stage_before_split(rank) and \ - is_pipeline_stage_after_split(rank+1) + return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split( + rank + 1 + ) def get_virtual_pipeline_model_parallel_rank(): @@ -422,32 +450,36 @@ def get_tensor_model_parallel_src_rank(): def get_data_parallel_src_rank(): """Calculate the global rank corresponding to the first local rank in the data parallel group.""" - assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \ - "Data parallel group is not initialized" + assert ( + _DATA_PARALLEL_GLOBAL_RANKS is not None + ), "Data parallel group is not initialized" return _DATA_PARALLEL_GLOBAL_RANKS[0] def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" return _PIPELINE_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert 
_PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" last_rank_local = get_pipeline_model_parallel_world_size() - 1 return _PIPELINE_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] @@ -455,8 +487,9 @@ def get_pipeline_model_parallel_next_rank(): def get_pipeline_model_parallel_prev_rank(): """Return the global rank that preceeds the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert ( + _PIPELINE_GLOBAL_RANKS is not None + ), "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] @@ -471,6 +504,7 @@ def get_data_parallel_rank(): """Return my rank for the data parallel group.""" return torch.distributed.get_rank(group=get_data_parallel_group()) + def destroy_model_parallel(): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP diff --git a/sarathi/model_executor/parallel_utils/pipeline_parallel/mappings.py b/sarathi/model_executor/parallel_utils/pipeline_parallel/mappings.py index 6feebd9..99fb5de 100644 --- a/sarathi/model_executor/parallel_utils/pipeline_parallel/mappings.py +++ b/sarathi/model_executor/parallel_utils/pipeline_parallel/mappings.py @@ -1,12 +1,12 @@ import torch +from sarathi.metrics.constants import OperationMetrics +from sarathi.metrics.cuda_timer import CudaTimer from sarathi.model_executor.parallel_utils.parallel_state import ( get_pipeline_model_parallel_group, get_pipeline_model_parallel_next_rank, get_pipeline_model_parallel_prev_rank, ) -from sarathi.metrics.constants import OperationMetrics -from sarathi.metrics.cuda_timer import CudaTimer def send(hidden_states: torch.tensor): diff --git a/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py b/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py index d17f12f..8fad070 100644 --- a/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py +++ b/sarathi/model_executor/parallel_utils/tensor_parallel/__init__.py @@ -2,32 +2,24 @@ ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, - set_tensor_model_parallel_attributes, - set_defaults_if_not_set_tensor_model_parallel_attributes, copy_tensor_model_parallel_attributes, param_is_not_tensor_parallel_duplicate, + set_defaults_if_not_set_tensor_model_parallel_attributes, + set_tensor_model_parallel_attributes, ) - from .mappings import ( copy_to_tensor_model_parallel_region, - gather_from_tensor_model_parallel_region, gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, - scatter_to_tensor_model_parallel_region, scatter_to_sequence_parallel_region, + scatter_to_tensor_model_parallel_region, ) - -from .random import ( - get_cuda_rng_tracker, - model_parallel_cuda_manual_seed, -) - -from .utils import ( - 
split_tensor_along_last_dim, -) +from .random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed +from .utils import split_tensor_along_last_dim __all__ = [ - #layers.py + # layers.py "ColumnParallelLinear", "RowParallelLinear", "VocabParallelEmbedding", diff --git a/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py b/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py index 014ee28..39c5a48 100644 --- a/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py +++ b/sarathi/model_executor/parallel_utils/tensor_parallel/layers.py @@ -12,33 +12,33 @@ from torch.nn.parameter import Parameter from sarathi.logger import init_logger +from sarathi.metrics.cuda_timer import CudaTimer from sarathi.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sarathi.metrics.cuda_timer import CudaTimer + from .mappings import ( gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region, ) - -from .utils import ( - divide, - VocabUtility, -) +from .utils import VocabUtility, divide logger = init_logger(__name__) -_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, - 'partition_dim': -1, - 'partition_stride': 1} +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { + "tensor_model_parallel": False, + "partition_dim": -1, + "partition_stride": 1, +} + def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, 'tensor_model_parallel') and - param.tensor_model_parallel) or ( - get_tensor_model_parallel_rank() == 0) + return ( + hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel + ) or (get_tensor_model_parallel_rank() == 0) def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): @@ -46,15 +46,16 @@ def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: assert not hasattr(tensor, attribute) # Set the attributes. 
- setattr(tensor, 'tensor_model_parallel', is_parallel) - setattr(tensor, 'partition_dim', dim) - setattr(tensor, 'partition_stride', stride) + setattr(tensor, "tensor_model_parallel", is_parallel) + setattr(tensor, "partition_dim", dim) + setattr(tensor, "partition_stride", stride) def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): def maybe_set(attribute, value): if not hasattr(tensor, attribute): setattr(tensor, attribute, value) + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) @@ -62,8 +63,8 @@ def maybe_set(attribute, value): def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): def maybe_copy(attribute): if hasattr(source_tensor, attribute): - setattr(destination_tensor, attribute, - getattr(source_tensor, attribute)) + setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: maybe_copy(attribute) @@ -84,16 +85,21 @@ class VocabParallelEmbedding(torch.nn.Module): perform_initialization """ - def __init__(self, num_embeddings: int, embedding_dim: int, *, - init_method=init.xavier_normal_, - params_dtype: torch.dtype=None, - use_cpu_initialization: bool=False, - perform_initialization: bool=False, - linear_metric_name:Optional[str]=None, - communication_metric_name:Optional[str]=None, - reduce_results: Optional[bool]=True, - world_size: Optional[int]=None, - rank: Optional[int]=None): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + *, + init_method=init.xavier_normal_, + params_dtype: torch.dtype = None, + use_cpu_initialization: bool = False, + perform_initialization: bool = False, + linear_metric_name: Optional[str] = None, + communication_metric_name: Optional[str] = None, + reduce_results: Optional[bool] = True, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ): super(VocabParallelEmbedding, self).__init__() assert not perform_initialization assert not use_cpu_initialization @@ -107,24 +113,33 @@ def __init__(self, num_embeddings: int, embedding_dim: int, *, # Set the defaults for compatibility. self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None - self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() if world_size is None else world_size + self.tensor_model_parallel_size = ( + get_tensor_model_parallel_world_size() if world_size is None else world_size + ) self.rank = get_tensor_model_parallel_rank() if rank is None else rank self.reduce_results = reduce_results # Divide the weight matrix along the vocaburaly dimension. 
- self.vocab_start_index, self.vocab_end_index = \ + self.vocab_start_index, self.vocab_end_index = ( VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, self.rank, - self.tensor_model_parallel_size) - self.num_embeddings_per_partition = self.vocab_end_index - \ - self.vocab_start_index - - self.weight = Parameter(torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, - device=torch.cuda.current_device(), dtype=params_dtype)) + self.num_embeddings, self.rank, self.tensor_model_parallel_size + ) + ) + self.num_embeddings_per_partition = ( + self.vocab_end_index - self.vocab_start_index + ) + + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) self._linear_timer = CudaTimer(linear_metric_name) self._communication_timer = CudaTimer(communication_metric_name) @@ -132,8 +147,9 @@ def __init__(self, num_embeddings: int, embedding_dim: int, *, def forward(self, input_): if self.tensor_model_parallel_size > 1: # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) + input_mask = (input_ < self.vocab_start_index) | ( + input_ >= self.vocab_end_index + ) # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 @@ -141,11 +157,15 @@ def forward(self, input_): masked_input = input_ # Get the embeddings. with self._linear_timer: - output_parallel = F.embedding(masked_input, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse) + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) # Mask the output embedding. if self.tensor_model_parallel_size > 1: @@ -187,19 +207,25 @@ class ColumnParallelLinear(torch.nn.Module): use_cpu_initialization: """ - def __init__(self, input_size, output_size, *, - bias=True, gather_output=True, - init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - params_dtype=None, - use_cpu_initialization=False, - perform_initialization=False, - linear_metric_name:Optional[str]=None, - communication_metric_name:Optional[str]=None, - world_size:Optional[int]=None, - layer_id:Optional[int]=None, - ): + def __init__( + self, + input_size, + output_size, + *, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + params_dtype=None, + use_cpu_initialization=False, + perform_initialization=False, + linear_metric_name: Optional[str] = None, + communication_metric_name: Optional[str] = None, + world_size: Optional[int] = None, + layer_id: Optional[int] = None, + ): super(ColumnParallelLinear, self).__init__() assert not perform_initialization assert not use_cpu_initialization @@ -209,7 +235,9 @@ def __init__(self, input_size, output_size, *, self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. 
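The vocabulary-partitioning logic reformatted above comes down to an even split plus a range check in forward(); a minimal sketch with made-up numbers:

    # Hypothetical: a padded vocab of 50_304 split across 4 tensor-parallel ranks.
    global_vocab_size, world_size, rank = 50_304, 4, 1
    per_partition = global_vocab_size // world_size     # divide() in the real code
    vocab_start = rank * per_partition
    vocab_end = vocab_start + per_partition
    print(vocab_start, vocab_end)                       # 12576 25152

    # forward() then masks out-of-range token ids before the local lookup:
    #   input_mask = (input_ < vocab_start) | (input_ >= vocab_end)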
- self.world_size = get_tensor_model_parallel_world_size() if world_size is None else world_size + self.world_size = ( + get_tensor_model_parallel_world_size() if world_size is None else world_size + ) self.output_size_per_partition = divide(output_size, self.world_size) self.skip_bias_add = skip_bias_add @@ -222,24 +250,34 @@ def __init__(self, input_size, output_size, *, self.create_weights(params_dtype) if bias: - self.bias = Parameter(torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype)) + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) set_tensor_model_parallel_attributes(self.bias, True, 0, stride) # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self._linear_timer = CudaTimer(linear_metric_name, layer_id=layer_id) - self._communication_timer = CudaTimer(communication_metric_name, layer_id=layer_id) + self._communication_timer = CudaTimer( + communication_metric_name, layer_id=layer_id + ) def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter(torch.empty( - self.output_size_per_partition, self.input_size, - device=torch.cuda.current_device(), dtype=dtype)) + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=dtype, + ) + ) def apply_weights( self, @@ -310,20 +348,26 @@ class RowParallelLinear(torch.nn.Module): reduce_results: """ - def __init__(self, input_size, output_size, *, - bias=True, input_is_parallel=False, - init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - params_dtype=None, - use_cpu_initialization=False, - perform_initialization=False, - reduce_results=True, - linear_metric_name:Optional[str]=None, - communication_metric_name:Optional[str]=None, - world_size:Optional[int]=None, - layer_id:Optional[int]=None, - ): + def __init__( + self, + input_size, + output_size, + *, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + params_dtype=None, + use_cpu_initialization=False, + perform_initialization=False, + reduce_results=True, + linear_metric_name: Optional[str] = None, + communication_metric_name: Optional[str] = None, + world_size: Optional[int] = None, + layer_id: Optional[int] = None, + ): super(RowParallelLinear, self).__init__() assert not perform_initialization assert not use_cpu_initialization @@ -337,34 +381,49 @@ def __init__(self, input_size, output_size, *, params_dtype = torch.get_default_dtype() # Divide the weight matrix along the last dimension. 
- self.world_size = get_tensor_model_parallel_world_size() if world_size is None else world_size + self.world_size = ( + get_tensor_model_parallel_world_size() if world_size is None else world_size + ) self.input_size_per_partition = divide(input_size, self.world_size) self.skip_bias_add = skip_bias_add self.create_weights(params_dtype) if not reduce_results and (bias and not skip_bias_add): - logger.warning("When not reduce the results, adding bias to the " - "results can lead to incorrect results") + logger.warning( + "When not reduce the results, adding bias to the " + "results can lead to incorrect results" + ) if bias: - self.bias = Parameter(torch.empty( - self.output_size, device=torch.cuda.current_device(), - dtype=params_dtype)) + self.bias = Parameter( + torch.empty( + self.output_size, + device=torch.cuda.current_device(), + dtype=params_dtype, + ) + ) # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self._linear_timer = CudaTimer(linear_metric_name, layer_id=layer_id) - self._communication_timer = CudaTimer(communication_metric_name, layer_id=layer_id) - + self._communication_timer = CudaTimer( + communication_metric_name, layer_id=layer_id + ) + def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter(torch.empty( - self.output_size, self.input_size_per_partition, - device=torch.cuda.current_device(), dtype=dtype)) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=dtype, + ) + ) def apply_weights(self, x: torch.Tensor) -> torch.Tensor: with self._linear_timer: diff --git a/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py b/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py index c1173cd..ea42000 100644 --- a/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py +++ b/sarathi/model_executor/parallel_utils/tensor_parallel/mappings.py @@ -5,10 +5,11 @@ import torch from sarathi.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tensor_model_parallel_group, ) + from .utils import split_tensor_along_last_dim @@ -16,7 +17,7 @@ def _reduce(input_): """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size()==1: + if get_tensor_model_parallel_world_size() == 1: return input_ # All-reduce. @@ -55,13 +56,14 @@ def _split_along_first_dim(input_): # Split along first dimension. 
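The split/gather helpers being restyled here shard activations by plain slicing; a small sketch of the first-dimension split for one rank (hypothetical shapes):

    import torch

    # Hypothetical: split an (8, 4) activation across 4 tensor-parallel ranks; this is rank 2.
    world_size, rank = 4, 2
    x = torch.arange(32, dtype=torch.float32).reshape(8, 4)
    local_dim = x.size(0) // world_size                 # 2 rows per rank
    local = x[rank * local_dim : (rank + 1) * local_dim].contiguous()
    print(local.shape)                                  # torch.Size([2, 4])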
dim_size = input_.size()[0] - assert dim_size % world_size == 0, \ - "First dimension of the tensor should be divisible by tensor parallel size" + assert ( + dim_size % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" local_dim_size = dim_size // world_size rank = get_tensor_model_parallel_rank() dim_offset = rank * local_dim_size - output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + output = input_[dim_offset : dim_offset + local_dim_size].contiguous() return output @@ -80,7 +82,9 @@ def _gather_along_last_dim(input_): tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) + torch.distributed.all_gather( + tensor_list, input_, group=get_tensor_model_parallel_group() + ) # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=last_dim).contiguous() @@ -99,13 +103,16 @@ def _gather_along_first_dim(input_): dim_size = list(input_.size()) dim_size[0] = dim_size[0] * world_size - output = torch.empty(dim_size, dtype=input_.dtype, - device=torch.cuda.current_device()) - torch.distributed._all_gather_base(output, input_.contiguous(), - group=get_tensor_model_parallel_group()) + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed._all_gather_base( + output, input_.contiguous(), group=get_tensor_model_parallel_group() + ) return output + def _reduce_scatter_along_first_dim(input_): """Reduce-scatter the input tensor across model parallel group.""" world_size = get_tensor_model_parallel_world_size() @@ -114,15 +121,18 @@ def _reduce_scatter_along_first_dim(input_): return input_ dim_size = list(input_.size()) - assert dim_size[0] % world_size == 0, \ - "First dimension of the tensor should be divisible by tensor parallel size" + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=input_.dtype, - device=torch.cuda.current_device()) - torch.distributed._reduce_scatter_base(output, input_.contiguous(), - group=get_tensor_model_parallel_group()) + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed._reduce_scatter_base( + output, input_.contiguous(), group=get_tensor_model_parallel_group() + ) return output @@ -252,6 +262,7 @@ def backward(ctx, grad_output): # Helper functions. 
# ----------------- + def copy_to_tensor_model_parallel_region(input_): return _CopyToModelParallelRegion.apply(input_) @@ -278,4 +289,3 @@ def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=Tru def reduce_scatter_to_sequence_parallel_region(input_): return _ReduceScatterToSequenceParallelRegion.apply(input_) - diff --git a/sarathi/model_executor/parallel_utils/tensor_parallel/random.py b/sarathi/model_executor/parallel_utils/tensor_parallel/random.py index 847ebdf..c0f9d79 100644 --- a/sarathi/model_executor/parallel_utils/tensor_parallel/random.py +++ b/sarathi/model_executor/parallel_utils/tensor_parallel/random.py @@ -9,14 +9,15 @@ import torch from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager +from torch.cuda import _lazy_call +from torch.cuda import device as device_ctx_manager from sarathi.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, ) # Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' +_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng" def _set_cuda_rng_state(new_state, device=-1): @@ -28,19 +29,20 @@ def _set_cuda_rng_state(new_state, device=-1): with a single change: the input state is not cloned. Cloning caused major performance issues for +4 GPU cases. """ - if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState): # older PyTorch def cb(): with device_ctx_manager(device): _C._cuda_setRNGState(new_state) + else: # newer PyTorch if device == -1: - device = torch.device('cuda') + device = torch.device("cuda") elif isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device('cuda', device) + device = torch.device("cuda", device) def cb(): idx = device.index @@ -52,7 +54,6 @@ def cb(): _lazy_call(cb) - class CudaRNGStatesTracker: """Tracker for the cuda RNG states. @@ -90,11 +91,11 @@ def add(self, name, seed): """Track the rng state.""" # Check seed is not already used. if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) + raise Exception("seed {} already exists".format(seed)) self.seeds_.add(seed) # Check that state is not already defined. if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) + raise Exception("cuda rng state {} already exists".format(name)) # Get the current rng state. orig_rng_state = torch.cuda.get_rng_state() # Set the new state and store it. @@ -109,7 +110,7 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): the original state.""" # Check if we have added the state if name not in self.states_: - raise Exception('cuda rng state {} is not added'.format(name)) + raise Exception("cuda rng state {} is not added".format(name)) # Store current rng state. orig_cuda_rng_state = torch.cuda.get_rng_state() # Set rng state to the desired one @@ -160,5 +161,6 @@ def model_parallel_cuda_manual_seed(seed): # Set the default state. torch.cuda.manual_seed(data_parallel_seed) # and model parallel state. 
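The RNG-tracker changes above are pure restyling; the underlying save/swap/restore pattern can be sketched with the CPU generator (illustrative only, not the CUDA path the tracker actually manages):

    import torch

    # Analogous to CudaRNGStatesTracker.fork(): stash the current state,
    # run under a tracked seed, then restore the original state.
    orig_state = torch.random.get_rng_state()
    torch.manual_seed(1234)
    sample = torch.rand(2)
    torch.random.set_rng_state(orig_state)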
- _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, - tensor_model_parallel_seed) + _CUDA_RNG_STATE_TRACKER.add( + _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed + ) diff --git a/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py b/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py index d892f77..d8b5f20 100644 --- a/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py +++ b/sarathi/model_executor/parallel_utils/tensor_parallel/utils.py @@ -2,9 +2,11 @@ # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -import torch from typing import List, Sequence +import torch + + def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format( @@ -24,16 +26,16 @@ def split_tensor_along_last_dim( num_partitions: int, contiguous_split_chunks: bool = False, ) -> List[torch.Tensor]: - """ Split a tensor along its last dimension. + """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. - Returns: - A list of Tensors + Returns: + A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 @@ -48,9 +50,9 @@ def split_tensor_along_last_dim( class VocabUtility: - """ Split the vocabulary into `world_size` chunks and return the first - and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last) + """Split the vocabulary into `world_size` chunks and return the first + and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last) """ @@ -63,7 +65,9 @@ def vocab_range_from_per_partition_vocab_size( return index_f, index_l @staticmethod - def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: + def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int + ) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) return VocabUtility.vocab_range_from_per_partition_vocab_size( per_partition_vocab_size, rank, world_size diff --git a/sarathi/model_executor/utils.py b/sarathi/model_executor/utils.py index a6a53e3..8887fca 100644 --- a/sarathi/model_executor/utils.py +++ b/sarathi/model_executor/utils.py @@ -1,12 +1,17 @@ """Utils for model executor.""" + import random from typing import List import numpy as np import torch -from sarathi.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized -from sarathi.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed +from sarathi.model_executor.parallel_utils.parallel_state import ( + model_parallel_is_initialized, +) +from sarathi.model_executor.parallel_utils.tensor_parallel import ( + model_parallel_cuda_manual_seed, +) def set_random_seed(seed: int) -> None: diff --git a/sarathi/model_executor/weight_utils.py b/sarathi/model_executor/weight_utils.py index a1d3d70..43339a5 100644 --- a/sarathi/model_executor/weight_utils.py +++ b/sarathi/model_executor/weight_utils.py @@ -1,15 +1,16 @@ 
"""Utilities for downloading and initializing model weights.""" -import filelock + import glob import json import os from collections import defaultdict from typing import Any, Iterator, List, Optional, Tuple -from huggingface_hub import snapshot_download -from safetensors.torch import load_file, save_file, safe_open +import filelock import numpy as np import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm from sarathi.logger import init_logger @@ -64,10 +65,12 @@ def convert_bin_to_safetensor_file( sf_size = os.stat(sf_filename).st_size pt_size = os.stat(pt_filename).st_size if (sf_size - pt_size) / pt_size > 0.01: - raise RuntimeError(f"""The file size different is more than 1%: + raise RuntimeError( + f"""The file size different is more than 1%: - {sf_filename}: {sf_size} - {pt_filename}: {pt_size} - """) + """ + ) # check if the tensors are the same reloaded = load_file(sf_filename) @@ -96,11 +99,13 @@ def prepare_hf_model_weights( # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=Disabledtqdm, - revision=revision) + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=Disabledtqdm, + revision=revision, + ) else: hf_folder = model_name_or_path hf_weights_files: List[str] = [] @@ -112,15 +117,16 @@ def prepare_hf_model_weights( ] if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt: - return prepare_hf_model_weights(model_name_or_path, - cache_dir=cache_dir, - use_safetensors=False, - fall_back_to_pt=False, - revision=revision) + return prepare_hf_model_weights( + model_name_or_path, + cache_dir=cache_dir, + use_safetensors=False, + fall_back_to_pt=False, + revision=revision, + ) if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") + raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`") return hf_folder, hf_weights_files, use_safetensors @@ -151,7 +157,8 @@ def hf_model_weights_iterator( cache_dir=cache_dir, use_safetensors=use_safetensors, fall_back_to_pt=fall_back_to_pt, - revision=revision) + revision=revision, + ) if use_np_cache: # Currently np_cache only support *.bin checkpoints @@ -225,7 +232,7 @@ def load_padded_tensor_parallel_vocab( end_idx = (tensor_model_parallel_rank + 1) * shard_size loaded_weight = loaded_weight[start_idx:end_idx] loaded_weight = convert_pyslice_to_tensor(loaded_weight) - param[:loaded_weight.shape[0]].copy_(loaded_weight) + param[: loaded_weight.shape[0]].copy_(loaded_weight) def load_tensor_parallel_weights( @@ -254,7 +261,8 @@ def load_tensor_parallel_weights( loaded_weight = convert_pyslice_to_tensor(loaded_weight) assert param.shape == loaded_weight.shape, ( f"{param_name} shape mismatch between model and checkpoint: " - f"{param.shape} != {loaded_weight.shape}") + f"{param.shape} != {loaded_weight.shape}" + ) param.data.copy_(loaded_weight) diff --git a/sarathi/transformers_utils/config.py b/sarathi/transformers_utils/config.py index 188168a..9713f46 100644 --- a/sarathi/transformers_utils/config.py +++ b/sarathi/transformers_utils/config.py @@ -12,20 +12,24 @@ } -def get_config(model: str, - trust_remote_code: bool, - revision: Optional[str] = None) -> 
PretrainedConfig: +def get_config( + model: str, trust_remote_code: bool, revision: Optional[str] = None +) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision) + model, trust_remote_code=trust_remote_code, revision=revision + ) except ValueError as e: - if (not trust_remote_code and - "requires you to execute the configuration file" in str(e)): + if ( + not trust_remote_code + and "requires you to execute the configuration file" in str(e) + ): err_msg = ( "Failed to load the model config. If the model is a custom " "model not yet available in the HuggingFace transformers " "library, consider setting `trust_remote_code=True` in LLM " - "or using the `--trust-remote-code` flag in the CLI.") + "or using the `--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e diff --git a/sarathi/transformers_utils/configs/__init__.py b/sarathi/transformers_utils/configs/__init__.py index 5276baa..e19cf31 100644 --- a/sarathi/transformers_utils/configs/__init__.py +++ b/sarathi/transformers_utils/configs/__init__.py @@ -1,8 +1,8 @@ -from sarathi.transformers_utils.configs.qwen import QWenConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from sarathi.transformers_utils.configs.falcon import RWConfig +from sarathi.transformers_utils.configs.qwen import QWenConfig from sarathi.transformers_utils.configs.yi import YiConfig __all__ = [ diff --git a/sarathi/transformers_utils/configs/falcon.py b/sarathi/transformers_utils/configs/falcon.py index 6d68ef4..6915e46 100644 --- a/sarathi/transformers_utils/configs/falcon.py +++ b/sarathi/transformers_utils/configs/falcon.py @@ -74,9 +74,7 @@ def __init__( # Hack for falcon-40b self.new_decoder_architecture = True - super().__init__(bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @property def head_dim(self): diff --git a/sarathi/transformers_utils/configs/yi.py b/sarathi/transformers_utils/configs/yi.py index 359922e..ea71d8c 100644 --- a/sarathi/transformers_utils/configs/yi.py +++ b/sarathi/transformers_utils/configs/yi.py @@ -1,4 +1,5 @@ """ Yi model configuration""" + from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -9,9 +10,10 @@ class YiConfig(PretrainedConfig): r""" - Reference: - https://huggingface.co/01-ai/Yi-6B/blob/main/configuration_yi.py + Reference: + https://huggingface.co/01-ai/Yi-6B/blob/main/configuration_yi.py """ + model_type = "Yi" keys_to_ignore_at_inference = ["past_key_values"] diff --git a/sarathi/transformers_utils/tokenizer.py b/sarathi/transformers_utils/tokenizer.py index 6bbc8ca..4eeaeb0 100644 --- a/sarathi/transformers_utils/tokenizer.py +++ b/sarathi/transformers_utils/tokenizer.py @@ -1,7 +1,6 @@ from typing import List, Optional, Tuple, Union -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from sarathi.logger import init_logger @@ -18,31 +17,30 @@ def get_tokenizer( """Gets a tokenizer for the given model name via Huggingface.""" if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise 
ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False try: tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs) + tokenizer_name, *args, trust_remote_code=trust_remote_code, **kwargs + ) except TypeError as e: # The LLaMA tokenizer causes a protobuf error in some environments. - err_msg = ("Failed to load the tokenizer.") + err_msg = "Failed to load the tokenizer." raise RuntimeError(err_msg) from e except ValueError as e: # If the error pertains to the tokenizer class not existing or not # currently being imported, suggest using the --trust-remote-code flag. - if (not trust_remote_code and - ("does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e))): + if not trust_remote_code and ( + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): err_msg = ( "Failed to load the tokenizer. If the tokenizer is a custom " "tokenizer not yet available in the HuggingFace transformers " "library, consider setting `trust_remote_code=True` in LLM " - "or using the `--trust-remote-code` flag in the CLI.") + "or using the `--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e @@ -50,7 +48,8 @@ def get_tokenizer( if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead.") + "slowdown. Consider using a fast tokenizer instead." + ) return tokenizer @@ -101,7 +100,8 @@ def detokenize_incrementally( if prev_tokens is None: try: new_tokens = tokenizer.convert_ids_to_tokens( - all_input_ids[-6:], skip_special_tokens=skip_special_tokens) + all_input_ids[-6:], skip_special_tokens=skip_special_tokens + ) except ValueError as e: new_tokens = ["[UNK]"] * 6 logger.warning(f"Warning: {e}", flush=True) @@ -116,7 +116,8 @@ def detokenize_incrementally( # Put new_token_id in a list so skip_special_tokens is respected try: new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) + [new_token_id], skip_special_tokens=skip_special_tokens + ) except ValueError as e: new_tokens = [prev_tokens[-1]] logger.warning(f"Warning: {e}", flush=True) @@ -127,25 +128,27 @@ def detokenize_incrementally( # surrounding ids. if tokenizer.is_fast or not tokenizer.get_added_vocab(): prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) + output_tokens[prefix_offset:read_offset] + ) + new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_offset:]) else: prefix_text = _convert_tokens_to_string_with_added_encoders( tokenizer, output_tokens[prefix_offset:read_offset], - skip_special_tokens=skip_special_tokens) + skip_special_tokens=skip_special_tokens, + ) new_text = _convert_tokens_to_string_with_added_encoders( tokenizer, output_tokens[prefix_offset:], - skip_special_tokens=skip_special_tokens) + skip_special_tokens=skip_special_tokens, + ) if len(new_text) > len(prefix_text) and not new_text.endswith("�"): # utf-8 char at the end means it's a potential unfinished byte sequence # from byte fallback tokenization. 
         # If it's in the middle, it's probably a real invalid id generated
         # by the model
-        new_text = new_text[len(prefix_text):]
+        new_text = new_text[len(prefix_text) :]
         return new_tokens, new_text, read_offset, len(output_tokens)
     else:
         return new_tokens, "", prefix_offset, read_offset
diff --git a/sarathi/utils/__init__.py b/sarathi/utils/__init__.py
index 62723e5..f65f3f8 100644
--- a/sarathi/utils/__init__.py
+++ b/sarathi/utils/__init__.py
@@ -62,7 +62,7 @@ def unset_cuda_visible_devices() -> None:
 
 def is_port_in_use(port: int) -> bool:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        return s.connect_ex(('localhost', port)) == 0
+        return s.connect_ex(("localhost", port)) == 0
 
 
 def get_random_port() -> int:
diff --git a/sarathi/utils/singleton.py b/sarathi/utils/singleton.py
index 8825a21..39ffc72 100644
--- a/sarathi/utils/singleton.py
+++ b/sarathi/utils/singleton.py
@@ -1,4 +1,4 @@
-""" 
+"""
 Singleton metaclass as described in
 https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python
 """
@@ -9,6 +9,5 @@ class Singleton(type):
 
     def __call__(cls, *args, **kwargs):
         if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton,
-                                        cls).__call__(*args, **kwargs)
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
         return cls._instances[cls]
diff --git a/sarathi/utils/threading_utils.py b/sarathi/utils/threading_utils.py
index 7f2f518..41fbcea 100644
--- a/sarathi/utils/threading_utils.py
+++ b/sarathi/utils/threading_utils.py
@@ -1,14 +1,15 @@
 import os
 import traceback
-from threading import Lock
 from functools import wraps
+from threading import Lock
 
 
 def synchronized(method):
-    """ Synchronization decorator at the instance level. """
+    """Synchronization decorator at the instance level."""
 
     @wraps(method)
     def synced_method(self, *args, **kwargs):
+        # pylint: disable=protected-access
         if not hasattr(self, "_lock"):
             self._lock = Lock()
 
@@ -24,8 +25,8 @@ def exit_on_error(func):
     def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)
-        except Exception:
+        except Exception:  # pylint: disable=broad-except
             traceback.print_exc()
-            os._exit(1)
+            os._exit(1)  # pylint: disable=protected-access
 
     return wrapper
diff --git a/sarathi/worker/base_worker.py b/sarathi/worker/base_worker.py
index 7961ab6..afe39de 100644
--- a/sarathi/worker/base_worker.py
+++ b/sarathi/worker/base_worker.py
@@ -1,28 +1,34 @@
 """A GPU worker class."""
+
 import os
 import time
-from typing import Tuple, Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed
 
-from sarathi.config import (CacheConfig, ModelConfig, ParallelConfig,
-                            MetricsConfig, BaseSchedulerConfig)
+from sarathi.config import (
+    BaseSchedulerConfig,
+    CacheConfig,
+    MetricsConfig,
+    ModelConfig,
+    ParallelConfig,
+)
+from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
+from sarathi.core.datatypes.sequence import SamplerOutputs, Sequence
+from sarathi.core.sequence_manager.worker_sequence_manager import WorkerSequenceManager
+from sarathi.logger import init_logger
+from sarathi.metrics.metrics_store import MetricsStore
 from sarathi.model_executor import set_random_seed
+from sarathi.model_executor.attention import set_attention_backend
+from sarathi.model_executor.model_runner import ModelRunner
 from sarathi.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel,
-    get_tensor_model_parallel_rank,
     get_pipeline_model_parallel_rank,
+    get_tensor_model_parallel_rank,
+    initialize_model_parallel,
 )
-from sarathi.core.datatypes.sequence import SamplerOutputs, Sequence
-from sarathi.worker.cache_engine import CacheEngine
-from sarathi.metrics.metrics_store import MetricsStore
 from sarathi.utils.threading_utils import synchronized
-from sarathi.model_executor.model_runner import ModelRunner
-from sarathi.logger import init_logger
-from sarathi.model_executor.attention import set_attention_backend
-from sarathi.core.sequence_manager.worker_sequence_manager import WorkerSequenceManager
-from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
+from sarathi.worker.cache_engine import CacheEngine
 
 logger = init_logger(__name__)
 
@@ -91,28 +97,36 @@ def init_model(self):
         torch.cuda.set_device(self.device)
 
         # Initialize the distributed environment.
-        _init_distributed_environment(self.parallel_config, self.rank,
-                                      self.distributed_init_method)
+        _init_distributed_environment(
+            self.parallel_config, self.rank, self.distributed_init_method
+        )
 
         self.tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         self.pipeline_model_parallel_rank = get_pipeline_model_parallel_rank()
 
         self.is_tensor_parallel_rank_zero = self.tensor_model_parallel_rank == 0
         self.is_first_pipeline_stage = self.pipeline_model_parallel_rank == 0
-        self.is_last_pipeline_stage = self.pipeline_model_parallel_rank == self.parallel_config.pipeline_parallel_size - 1
+        self.is_last_pipeline_stage = (
+            self.pipeline_model_parallel_rank
+            == self.parallel_config.pipeline_parallel_size - 1
+        )
 
         logger.info(
             f"Initializing worker {self.rank} on device {self.device}, "
             f"tensor parallel rank {self.tensor_model_parallel_rank} "
-            f"and pipeline parallel rank {self.pipeline_model_parallel_rank}.")
+            f"and pipeline parallel rank {self.pipeline_model_parallel_rank}."
+        )
 
         # Initialize the model.
         set_random_seed(self.model_config.seed)
-        self.model_runner = ModelRunner(self.model_config,
-                                        self.parallel_config,
-                                        self.scheduler_config,
-                                        self.cache_config, self.device,
-                                        self.rank)
+        self.model_runner = ModelRunner(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.cache_config,
+            self.device,
+            self.rank,
+        )
         logger.info(f"Model initialized on worker {self.rank}.")
 
     @torch.inference_mode()
@@ -122,8 +136,9 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
 
         self.cache_config = cache_config
 
-        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
-                                        self.parallel_config)
+        self.cache_engine = CacheEngine(
+            self.cache_config, self.model_config, self.parallel_config
+        )
         self.gpu_cache = self.cache_engine.gpu_cache
 
         self.seq_manager = WorkerSequenceManager(
@@ -139,8 +154,9 @@ def add_seq(self, seq: Sequence) -> None:
     def get_model_parallel_ranks(self) -> Tuple[int, int]:
         return self.tensor_model_parallel_rank, self.pipeline_model_parallel_rank
 
-    def on_step_completed(self, scheduler_outputs: SchedulerOutputs,
-                          sampler_outputs: SamplerOutputs) -> None:
+    def on_step_completed(
+        self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs
+    ) -> None:
         self.seq_manager.on_step_completed(scheduler_outputs, sampler_outputs)
 
     @torch.inference_mode()
@@ -186,10 +202,12 @@ def reset_metrics(self) -> None:
 
     @synchronized
     def start_profiling(self) -> None:
-        self.profiler = torch.profiler.profile(activities=[
-            torch.profiler.ProfilerActivity.CPU,
-            torch.profiler.ProfilerActivity.CUDA,
-        ], )
+        self.profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+        )
         self.profiler.__enter__()
 
     @synchronized
@@ -199,7 +217,8 @@ def profile_num_available_blocks(
         gpu_memory_utilization: float,
     ) -> Tuple[int, int]:
         return self.model_runner.profile_num_available_blocks(
-            block_size, gpu_memory_utilization)
+            block_size, gpu_memory_utilization
+        )
 
     @synchronized
     def stop_profiling(self) -> None:
@@ -221,11 +240,13 @@
         raise RuntimeError(
             "torch.distributed is already initialized but the torch world "
             "size does not match parallel_config.world_size "
-            f"({torch_world_size} vs. {parallel_config.world_size}).")
+            f"({torch_world_size} vs. {parallel_config.world_size})."
+        )
     elif not distributed_init_method:
         raise ValueError(
             "distributed_init_method must be set if torch.distributed "
-            "is not already initialized")
+            "is not already initialized"
+        )
     else:
         torch.distributed.init_process_group(
             backend="nccl",
@@ -236,5 +257,6 @@
 
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
-    initialize_model_parallel(parallel_config.tensor_parallel_size,
-                              parallel_config.pipeline_parallel_size)
+    initialize_model_parallel(
+        parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
+    )
diff --git a/sarathi/worker/cache_engine.py b/sarathi/worker/cache_engine.py
index 0b51d91..d839248 100644
--- a/sarathi/worker/cache_engine.py
+++ b/sarathi/worker/cache_engine.py
@@ -1,12 +1,13 @@
 """CacheEngine class for managing the KV cache."""
+
 from typing import List, Tuple, Union
 
 import torch
 
 from sarathi.config import CacheConfig, ModelConfig, ParallelConfig
 from sarathi.logger import init_logger
-from sarathi.utils import in_wsl
 from sarathi.model_executor.attention import get_attention_wrapper
+from sarathi.utils import in_wsl
 
 logger = init_logger(__name__)
 
@@ -45,7 +46,8 @@ def allocate_gpu_cache(self) -> List[torch.Tensor]:
 
         for _ in range(self.num_layers):
             gpu_blocks = get_attention_wrapper().get_cache_block(
-                self.num_gpu_blocks, dtype=self.dtype, device="cuda")
+                self.num_gpu_blocks, dtype=self.dtype, device="cuda"
+            )
             gpu_cache.append(gpu_blocks)
         return gpu_cache
diff --git a/sarathi/worker/pipeline_parallel_worker.py b/sarathi/worker/pipeline_parallel_worker.py
index a5a3856..5475f82 100644
--- a/sarathi/worker/pipeline_parallel_worker.py
+++ b/sarathi/worker/pipeline_parallel_worker.py
@@ -1,17 +1,23 @@
 """A GPU worker class."""
-from threading import Thread
-from typing import Tuple, Optional
+
 from queue import Queue
+from threading import Thread
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed
 
-from sarathi.config import (CacheConfig, ModelConfig, ParallelConfig,
-                            MetricsConfig, BaseSchedulerConfig)
+from sarathi.config import (
+    BaseSchedulerConfig,
+    CacheConfig,
+    MetricsConfig,
+    ModelConfig,
+    ParallelConfig,
+)
+from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
 from sarathi.core.datatypes.sequence import SamplerOutputs
-from sarathi.utils.threading_utils import synchronized, exit_on_error
 from sarathi.logger import init_logger
-from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
+from sarathi.utils.threading_utils import exit_on_error, synchronized
 from sarathi.worker.base_worker import BaseWorker
 
 logger = init_logger(__name__)
 
@@ -36,13 +42,19 @@ def __init__(
         rank: Optional[int] = None,
         distributed_init_method: Optional[str] = None,
     ) -> None:
-        super().__init__(model_config, parallel_config, scheduler_config,
-                         cache_config, metrics_config, local_rank, rank,
-                         distributed_init_method)
+        super().__init__(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            cache_config,
+            metrics_config,
+            local_rank,
+            rank,
+            distributed_init_method,
+        )
         self.execution_queue = Queue()
         self.output_queue = Queue()
-        self.execution_thread = Thread(target=self._execution_loop,
-                                       daemon=True)
+        self.execution_thread = Thread(target=self._execution_loop, daemon=True)
 
     def _verify_parallel_config(self) -> None:
         assert self.parallel_config.pipeline_parallel_size > 1
@@ -57,15 +69,17 @@ def enqueue(
     ) -> None:
         self.execution_queue.put(scheduler_outputs)
 
-    def on_step_completed(self, scheduler_outputs: SchedulerOutputs,
-                          sampler_outputs: SamplerOutputs) -> None:
+    def on_step_completed(
+        self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs
+    ) -> None:
         # in pipeline parallel case, each stage won't have sampler output
         # so we need to do the book keeping update later
         pass
 
     @synchronized
-    def on_sampling_completed(self, scheduler_outputs: SchedulerOutputs,
-                              sampler_outputs: SamplerOutputs) -> None:
+    def on_sampling_completed(
+        self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs
+    ) -> None:
         self.seq_manager.on_step_completed(scheduler_outputs, sampler_outputs)
 
     @exit_on_error