LLM.int8() Refactoring: Part 1 #1401

Merged: 72 commits, Dec 5, 2024

Commits (72)
0cc5c95
Start of int8 refactor: remove col32/col_ampere/col_turing transforms…
matthewdouglas Oct 7, 2024
0f2dc34
Fix unintended change
matthewdouglas Oct 8, 2024
50fe50e
New naive mm_dequant kernel for row-major; cleanup
matthewdouglas Oct 9, 2024
57e6427
fix
matthewdouglas Oct 9, 2024
ca372f2
int8 refactor: initial sparse decomp, cleanup
matthewdouglas Oct 14, 2024
510a880
Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup
matthewdouglas Oct 14, 2024
0ab14fe
int8: inference optimizations, some cleanup
matthewdouglas Oct 18, 2024
fdf4745
int8: more tests passing, cleanup
matthewdouglas Oct 18, 2024
d231db7
int8 - more cleanup, most tests passing
matthewdouglas Oct 21, 2024
dfc4668
int8: specify CUDA stream for int8 ops
matthewdouglas Oct 22, 2024
01bf54e
perf: reduce overhead from getting cudaStream ptr
matthewdouglas Oct 24, 2024
32979b4
Mark some functions for deprecation.
matthewdouglas Oct 24, 2024
521da0c
int8 sparse decomp: small perf improvement
matthewdouglas Oct 30, 2024
c75eecd
Merge branch 'main' into int8
matthewdouglas Oct 30, 2024
217cf8e
update setup.py
matthewdouglas Oct 30, 2024
b9cb5c9
Update bitsandbytes/autograd/_functions.py
matthewdouglas Oct 31, 2024
6fa7905
Update bitsandbytes/functional.py
matthewdouglas Oct 31, 2024
e929df0
Update bitsandbytes/functional.py
matthewdouglas Oct 31, 2024
c7b31df
Update bitsandbytes/research/autograd/_functions.py
matthewdouglas Oct 31, 2024
57300e7
int8 - perf improvement for sparse decomposition inference; deprecate…
matthewdouglas Nov 4, 2024
0460d2e
int8 cleanup
matthewdouglas Nov 4, 2024
437a17e
Merge branch 'int8' of https://github.com/TimDettmers/bitsandbytes in…
matthewdouglas Nov 4, 2024
762daf4
Ignore ruff rule ISC001 (incompatible with formatter)
matthewdouglas Nov 4, 2024
875414e
add comment
matthewdouglas Nov 4, 2024
0aefeb0
int8 more cleanup
matthewdouglas Nov 4, 2024
bfb42d1
Update bitsandbytes/functional.py
matthewdouglas Nov 4, 2024
1ae2476
Merge branch 'int8' of https://github.com/TimDettmers/bitsandbytes in…
matthewdouglas Nov 4, 2024
bf002db
int8: rename / deprecate old fn signatures
matthewdouglas Nov 4, 2024
7f6fb60
Update bitsandbytes/functional.py
matthewdouglas Nov 4, 2024
135b336
type annotation
matthewdouglas Nov 4, 2024
5388877
format update
matthewdouglas Nov 4, 2024
be2e98f
Update bitsandbytes/research/autograd/_functions.py
matthewdouglas Nov 4, 2024
4c849bb
cleanup
matthewdouglas Nov 4, 2024
32a60c5
Merge branch 'int8' of https://github.com/TimDettmers/bitsandbytes in…
matthewdouglas Nov 4, 2024
b954474
Add comment to explain division optimization
matthewdouglas Nov 4, 2024
35dbb2e
more cleanup
matthewdouglas Nov 4, 2024
b36003f
Update bitsandbytes/functional.py
matthewdouglas Nov 5, 2024
03a1963
Update bitsandbytes/functional.py
matthewdouglas Nov 5, 2024
980279f
Update bitsandbytes/functional.py
matthewdouglas Nov 5, 2024
a72c463
cleanup
matthewdouglas Nov 5, 2024
b5d6135
Merge branch 'int8' of https://github.com/TimDettmers/bitsandbytes in…
matthewdouglas Nov 5, 2024
b1c4adc
Type annotations, cleanup
matthewdouglas Nov 5, 2024
ed922b8
remove unused kernels; improved type annotations
matthewdouglas Nov 5, 2024
a93b91f
small perf optimization for single-GPU systems
matthewdouglas Nov 5, 2024
4bced86
small perf optimization for single-GPU systems
matthewdouglas Nov 5, 2024
f61d8bc
update docstrings
matthewdouglas Nov 18, 2024
eed9c3c
Improve docs and tests
matthewdouglas Nov 18, 2024
6e0a4b3
Update docstring
matthewdouglas Nov 18, 2024
161c194
Update test
matthewdouglas Nov 18, 2024
0ac1452
Merge branch 'main' into int8
matthewdouglas Nov 19, 2024
e3051fa
add benchmarking script
matthewdouglas Nov 20, 2024
56abdc2
test cleanup: add deprecated marker, move benchmarks out
matthewdouglas Nov 20, 2024
df941ec
Add int8 dequant function; misc improvements
matthewdouglas Nov 25, 2024
73f02e8
int8 matmul fallback for inner dims not divisible by 4
matthewdouglas Nov 25, 2024
ebb6797
improve register usage of kInt8VectorQuant - especially for A100/H100
matthewdouglas Nov 27, 2024
196c8e0
disable fail-fast for package build
matthewdouglas Nov 27, 2024
fa6f597
maxwell compat
matthewdouglas Nov 27, 2024
498d8de
ptxas verbose
matthewdouglas Nov 27, 2024
a2ee1c4
docs update
matthewdouglas Nov 29, 2024
ac09570
Merge branch 'main' into int8
matthewdouglas Dec 2, 2024
15f1661
doc update
matthewdouglas Dec 2, 2024
5d536c6
backward fix
matthewdouglas Dec 2, 2024
5b2348b
Bugfix sparse decomp
matthewdouglas Dec 3, 2024
bbb7063
Int8 fix for PEFT OLoRA init
matthewdouglas Dec 3, 2024
d25ebb4
Fix test for deprecated spmm_coo
matthewdouglas Dec 3, 2024
3d595f1
test improvement
matthewdouglas Dec 3, 2024
03fcabd
doc update
matthewdouglas Dec 4, 2024
582bf22
typo
matthewdouglas Dec 4, 2024
1ae7c6b
doc cleanup
matthewdouglas Dec 4, 2024
213b10b
docs
matthewdouglas Dec 4, 2024
ca6fd44
add inference benchmark script
matthewdouglas Dec 4, 2024
b8c736b
Add benchmarks, doc update
matthewdouglas Dec 4, 2024
30 changes: 15 additions & 15 deletions .github/scripts/build-cuda.sh
@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
done

if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi


output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
@@ -60,6 +60,7 @@ jobs:
##
build-shared-libs-cuda:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64]
2 changes: 2 additions & 0 deletions .gitignore
@@ -22,9 +22,11 @@ CMakeFiles/
bitsandbytes.dir/
Debug/
Release/
cmake-build-*/

# IDE local files
.vs/
.idea/

# Distribution / packaging
.Python
14 changes: 1 addition & 13 deletions CMakeLists.txt
@@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" )
endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON)
set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -166,9 +163,6 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES})

string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
if(NO_CUBLASLT)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS)
if(NOT APPLE)
@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)

if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
if(NO_CUBLASLT)
target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
else()
target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
endif()

target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
set_target_properties(bitsandbytes
PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
159 changes: 159 additions & 0 deletions benchmarking/README.md
@@ -0,0 +1,159 @@
# Benchmarking

## Inference
End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.

See the example script in
[inference_benchmark.py](inference_benchmark.py).
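
For reference, a minimal invocation might look like the sketch below (the model id and argument values are illustrative; any Hugging Face checkpoint id accepted by `transformers` should work):

```sh
# Benchmark a model under a few quantization configs and batch sizes.
# JSON reports are written under --out-dir.
python benchmarking/inference_benchmark.py Qwen/Qwen2.5-3B-Instruct \
    --configs nf4 int8 int8-decomp \
    --batches 1 8 32 \
    --input-length 64 \
    --out-dir reports
```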

### Results (as of v0.45.0)

Our overall benchmarking results compared with v0.44.1 provide the following insights:
#### LLM.int8()
* **Turing/Ampere/Ada**: The observed per-token throughput is improved by 60-85%, while latency is decreased by 40-45%.
* **H100**: In our benchmarks of Llama 3.1 70B, the new LLM.int8() consistently outperforms NF4 at batch sizes >= 8.

#### NF4/FP4
* **Turing/Ampere/Ada**: With batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_.
* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_.

Summaries with the benchmarking results are provided below.
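
In the tables, the quantization labels correspond to the settings exercised by the benchmark script. The sketch below shows how they map onto `transformers`' `BitsAndBytesConfig` (an assumption for illustration; the benchmark itself passes equivalent settings through `optimum-benchmark`):

```python
# Sketch: table labels -> quantization settings, mirroring benchmarking/inference_benchmark.py.
# On pre-Ampere GPUs (no bfloat16 support) the script falls back to float16 compute.
import torch
from transformers import BitsAndBytesConfig

configs = {
    # "INT8": pure int8 matmul, outlier decomposition disabled (threshold 0.0)
    "INT8": BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=0.0),
    # "INT8+Decomp": LLM.int8() with mixed-precision decomposition for outliers (threshold 6.0)
    "INT8+Decomp": BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0),
    # "NF4": 4-bit NormalFloat weights
    "NF4": BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
    ),
    # "NF4+DQ": NF4 with double quantization of the quantization constants
    "NF4+DQ": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
}
```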

#### NVIDIA T4 16GB
<details>
<summary>Qwen 2.5 3B Instruct</summary>

| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x |
| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x |
| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x |
| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x |
| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x |
| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x |
| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x |
| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x |
| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x |
| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x |
| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x |
| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x |
| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x |
| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x |
| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x |
</details>

#### NVIDIA RTX 4090 24GB
<details>
<summary>Llama 3.1 8B</summary>

| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x |
| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x |
| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x |
| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x |
| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x |
| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x |
| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x |
| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x |
| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x |
| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x |
| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x |
| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x |
| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x |
| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x |
| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x |
</details>

<details>
<summary>Qwen 2.5 14B Instruct</summary>

| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x |
| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x |
| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x |
| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x |
| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x |
| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x |
| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x |
| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x |
</details>


#### NVIDIA H100 80GB SXM
<details>
<summary>Llama 3.1 8B</summary>

| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x |
| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x |
| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x |
| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A |
| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x |
| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x |
| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x |
| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A |
| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x |
| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x |
| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x |
| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A |
</details>

<details>
<summary>Qwen 2.5 32B Instruct</summary>

| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| BF16 | 1 | 0.0508 | 19.67 |
| NF4 | 1 | 0.0707 | 14.14 |
| NF4+DQ | 1 | 0.0860 | 11.63 |
| INT8 | 1 | 0.1031 | 9.70 |
| INT8+Decomp | 1 | 0.1820 | 5.49 |
| BF16 | 8 | 0.0525 | 152.50 |
| NF4 | 8 | 0.1154 | 69.35 |
| NF4+DQ | 8 | 0.1209 | 66.19 |
| INT8 | 8 | 0.1078 | 74.24 |
| INT8+Decomp | 8 | 0.1958 | 40.87 |
| BF16 | 32 | 0.0547 | 584.54 |
| NF4 | 32 | 0.1246 | 256.84 |
| NF4+DQ | 32 | 0.1298 | 246.47 |
| INT8 | 32 | 0.1056 | 302.96 |
| INT8+Decomp | 32 | 0.2027 | 157.83 |
</details>

<details>
<summary>Llama 3.1 70B</summary>

| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| NF4 | 1 | 0.0833 | 12.00 |
| NF4+DQ | 1 | 0.1052 | 9.50 |
| INT8 | 1 | 0.1294 | 7.73 |
| INT8+Decomp | 1 | 0.1985 | 5.04 |
| NF4 | 8 | 0.2348 | 34.07 |
| NF4+DQ | 8 | 0.2423 | 33.01 |
| INT8 | 8 | 0.1313 | 60.94 |
| INT8+Decomp | 8 | 0.2052 | 38.99 |
| NF4 | 32 | 0.2491 | 128.46 |
| NF4+DQ | 32 | 0.2580 | 124.04 |
| INT8 | 32 | 0.1314 | 243.45 |
| INT8+Decomp | 32 | 0.2189 | 146.19 |
</details>

#### Software Configuration
We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd).

For all hardware configurations, we used the following dependencies:
* `transformers==4.46.3`
* `accelerate==1.1.1`
* `tokenizers==0.20.3`
* `torch==2.5.1`
* `bitsandbytes==0.44.1`
* `bitsandbytes==0.45.0.dev`

For the RTX 4090 setting we used the CUDA 12.4 build of PyTorch; for the other settings we used the CUDA 12.1 build.
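
As a rough sketch of how this environment could be reproduced (the version pins come from the list above; the `optimum-benchmark` install line, pinned to the commit noted earlier, is an assumption about the install method):

```sh
# Hypothetical environment setup for the benchmarks above.
pip install transformers==4.46.3 accelerate==1.1.1 tokenizers==0.20.3 torch==2.5.1
pip install "optimum-benchmark @ git+https://github.com/huggingface/optimum-benchmark.git@6e6b10363f3ac65926881f2c6a6113b6cefc06cd"
# Install the bitsandbytes build under test, e.g. the v0.44.1 baseline:
pip install bitsandbytes==0.44.1
```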
134 changes: 134 additions & 0 deletions benchmarking/inference_benchmark.py
@@ -0,0 +1,134 @@
"""
Inference benchmarking tool.

Requirements:
transformers
accelerate
bitsandbytes
optimum-benchmark

Usage: python inference_benchmark.py model_id

options:
-h, --help show this help message and exit
--configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...]
--bf16
--fp16
--nf4
--nf4-dq
--int8
--int8-decomp
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
"""

import argparse
from pathlib import Path

from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging
import torch

BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8

WEIGHTS_CONFIGS = {
    "fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}},
    "bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}},
    "nf4": {
        "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_use_double_quant": False,
            "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
        },
    },
    "nf4-dq": {
        "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
        },
    },
    "int8-decomp": {
        "torch_dtype": "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_8bit": True,
            "llm_int8_threshold": 6.0,
        },
    },
    "int8": {
        "torch_dtype": "float16",
        "quantization_scheme": "bnb",
        "quantization_config": {
            "load_in_8bit": True,
            "llm_int8_threshold": 0.0,
        },
    },
}

if __name__ == "__main__":
setup_logging(level="INFO")

parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")

parser.add_argument("model_id", type=str, help="The model checkpoint to use.")

parser.add_argument(
"--configs",
nargs="+",
choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"],
default=["nf4", "int8", "int8-decomp"],
)
parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16")
parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16")
parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4")
parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq")
parser.add_argument("--int8", dest="configs", action="append_const", const="int8")
parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp")

parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32])
parser.add_argument("--input-length", type=int, default=64)

parser.add_argument("--out-dir", type=str, default="reports")

args = parser.parse_args()

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

for batch_size in args.batches:
print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)

out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"

benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)