Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM.int8() Refactoring: Part 1 #1401

Merged
merged 72 commits into from
Dec 5, 2024
Merged

LLM.int8() Refactoring: Part 1 #1401

merged 72 commits into from
Dec 5, 2024

Conversation

matthewdouglas
Copy link
Member

@matthewdouglas matthewdouglas commented Oct 24, 2024

This PR is the initial phase of a set of changes aimed at improving the LLM.int8() implementation.

@TimDettmers @Titus-von-Koeller

Primary Purpose

Enhancements

  • Removes the usage of Turing and Ampere specific memory layouts while retaining compatibility across sm_75 through sm_89.
    • Simplifies the code and surface area needing to be maintained.
    • Reduced overhead by removing layout transformation operations.
  • Removes the separate NO_CUBLASLT build while retaining compatibility for targets below sm_75. verification in progress
    • This simplifies building and packaging, and trims the size of binary wheels in ~half.
  • Support for CUDA Graph tracing to bring parity with 4bit.
  • Improved kernels for inference:
    • Fused kernel for activation scale calibration and quantization. (Exposed as op F.int8_vectorwise_quant)
    • Other kernels simplified to operate with row-major data.
  • Makes many unit tests more reliable with increased determinism.

Deprecations

The following functions from bitsandbytes are deprecated:

mm_cublas
bmm_cublas
matmul_cublas

The following functions from bitsandbytes.functional are deprecated:

_mul
arange
dequant_min_max
dequantize_no_absmax
extract_outliers
get_special_format_str
get_tensor_stream (moved to internal API)
get_transform_buffer
get_transform_func
mm_dequant (replacement: int8_mm_dequant)
igemmlt (replacement: int8_linear_matmul)
nvidia_transform
post_call
pre_call
transform
quantize_no_absmax
vectorwise_dequant
vectorwise_quant (~replacement: int8_vectorwise_quant)
vectorwise_mm_dequant (~replacement: int8_mm_dequant)

Further testing and benchmarking will be coming. At the moment unit tests are passing.

Next steps

  • Clean up and reorganize unit tests
  • Documentation for public APIs
  • Ensure fallback path for shapes that don't work well with cuBLASLt (i.e. m/k not multiples of 4).
  • Add an int8 dequantize op
  • Further improvement of sparse decomposition performance (Deferred to future PRs)
  • Conduct profiling, benchmarks, and evaluations
    • Build benchmark/evaluation scripts
    • Prepare analysis of results

Copy link

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -1,6 +1,5 @@
from dataclasses import dataclass
from functools import reduce # Required in Python 3
import operator
from math import prod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We support Python 3.8+ only, so use the builtin.

@@ -245,11 +238,11 @@ class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False
CB = None
CxB = None
CxB = None # TODO: Deprecate/remove
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be used anymore but I'm not sure of the side-effects of removing these properties either. Could be downstream integrations accessing them, maybe used in serialization etc. Any tips/thoughts here are welcome.

Comment on lines +345 to +347
# Zero out the outliers in the transposed 8bit inputs.
if CAt is not None:
CAt[:, state.idx] = 0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We skip this for inference now as it's also not needed.

Comment on lines -439 to +431
if t is None:
continue # NULL pointers are fine
is_paged = getattr(t, "is_paged", False)
on_gpu &= t.device.type == "cuda" or is_paged
if not is_paged:
# NULL pointers and paged tensors are OK.
if t is not None and not getattr(t, "is_paged", False):
on_gpu &= t.is_cuda
gpu_ids.add(t.device.index)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't specific for int8, but while I was profiling I noticed an opportunity to slightly improve some of the overhead here.

csrc/kernels.cu Outdated Show resolved Hide resolved
Comment on lines -3528 to +3574
for(int i = threadIdx.x; i < 16; i++)
quant_map[i] = T(datatype[i]);
if (threadIdx.x < 16)
quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
//for(int i = threadIdx.x; i < 16; i++)
//quant_map[i] = T(__ldg(&datatype[i]));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not int8 but another small 4bit change that wanted to sneak its way in. @TimDettmers I'm just looking for a sanity check here that this makes sense.

@matthewdouglas matthewdouglas added the enhancement New feature or request label Oct 24, 2024
@matthewdouglas matthewdouglas marked this pull request as ready for review November 25, 2024 17:37
@matthewdouglas
Copy link
Member Author

There have now been some documentation updates both for the inline docstrings and the markdown-format public docs.

Additionally, tests related to 8bit now use static shapes. Certain tests related to benchmarking have been extracted away, and others have had a new deprecated marker applied where appropriate.

A more detailed look at benchmarking data will be provided with release materials. For now, an overview of inference benchmark results:

  • INT8
    • On T4 and 4090, the per-token throughput is improved by 60-85% and per-token latency is decreased by 40-45%.
    • H100 is now supported. With Llama 3.1 70B and batch size >= 8, INT8 is consistently faster than NF4.
  • NF4:
    • On T4 and 4090, with batch size of 1, per-token throughput is improved by 10-25% and per-token latency is decreased by 10-20%.
    • On H100, across all batch sizes, per-token throughput is improved by up to 28% and per-token latency is decreased by up to 22%.

@Titus-von-Koeller
Copy link
Collaborator

Really well done @matthewdouglas! What a thorough and well done refactor, which definitely needed some serious skill and dedication: Excellent work and results!

Imo this is ready to merge now, just the docs deep dive and some small improvements that we discussed in Slack and then 🚀✨

@matthewdouglas
Copy link
Member Author

Benchmark details have been added. I've also confirmed that everything is working on V100 without the separate NO_CUBLASLT build needed.

@matthewdouglas matthewdouglas merged commit 81e6345 into main Dec 5, 2024
60 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request Medium risk Risk of bugs in transformers and other libraries
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug] Exception: cublasLt ran into an error! during fine-tuning LLM in 8bit mode
3 participants