LLM.int8() Refactoring: Part 1 #1401
Conversation
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import reduce  # Required in Python 3
-import operator
+from math import prod
We support Python 3.8+ only, so use the builtin.
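For readers following along, a quick sketch of the equivalence (the shape tuple here is just an example):

```python
from functools import reduce
from math import prod
import operator

shape = (4, 16, 64)

# The pre-3.8 idiom that the diff removes:
n_old = reduce(operator.mul, shape, 1)

# The standard-library replacement, available since Python 3.8:
n_new = prod(shape)

assert n_old == n_new == 4096
```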
bitsandbytes/autograd/_functions.py
@@ -245,11 +238,11 @@ class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
-    CxB = None
+    CxB = None  # TODO: Deprecate/remove
This won't be used anymore, but I'm not sure of the side effects of removing these properties either. Downstream integrations could be accessing them, and they may be used in serialization, etc. Any tips/thoughts here are welcome.
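One common way to handle this (a hedged sketch of a general pattern, not necessarily what this PR does) is to keep the attribute behind a property that warns on access, so downstream integrations and serialization keep working while the deprecation is signaled:

```python
import warnings

class MatmulLtStateSketch:
    # Illustrative stand-in; the real MatmulLtState is a dataclass in
    # bitsandbytes/autograd/_functions.py.
    _CxB = None  # retained only for backward compatibility

    @property
    def CxB(self):
        warnings.warn(
            "MatmulLtState.CxB is deprecated and no longer used.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._CxB

    @CxB.setter
    def CxB(self, value):
        warnings.warn(
            "Assigning to MatmulLtState.CxB is deprecated; the new igemmlt "
            "path does not read it.",
            DeprecationWarning,
            stacklevel=2,
        )
        self._CxB = value
```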
+        # Zero out the outliers in the transposed 8bit inputs.
+        if CAt is not None:
+            CAt[:, state.idx] = 0
We now skip this for inference, as it's not needed there.
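For intuition, a self-contained sketch of what the guarded operation does (shapes and outlier indices are made up): in LLM.int8(), outlier columns are handled in fp16, so they must be zeroed out of the int8 copy kept for the backward pass.

```python
import torch

# CAt: transposed int8 activations kept for the backward pass (hypothetical shape).
# During inference no backward pass is needed, so CAt is None and this is skipped.
CAt = torch.randint(-127, 128, (64, 8), dtype=torch.int8)

# idx: columns identified as outliers and routed through the fp16 path.
idx = torch.tensor([2, 5])

# Zero the outlier columns so they don't also contribute via the int8 path.
if CAt is not None:
    CAt[:, idx] = 0

assert bool((CAt[:, idx] == 0).all())
```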
-        if t is None:
-            continue  # NULL pointers are fine
-        is_paged = getattr(t, "is_paged", False)
-        on_gpu &= t.device.type == "cuda" or is_paged
-        if not is_paged:
+        # NULL pointers and paged tensors are OK.
+        if t is not None and not getattr(t, "is_paged", False):
+            on_gpu &= t.is_cuda
             gpu_ids.add(t.device.index)
This isn't specific to int8, but while I was profiling I noticed an opportunity to slightly reduce some of the overhead here.
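For context, here's roughly how the streamlined check could sit in a full validation helper (a sketch reconstructed around the diff; the function name and error messages are assumptions):

```python
def is_on_gpu(tensors):
    """Check that all non-null, non-paged tensors live on a single CUDA device."""
    on_gpu = True
    gpu_ids = set()

    for t in tensors:
        # NULL pointers and paged tensors are OK.
        if t is not None and not getattr(t, "is_paged", False):
            on_gpu &= t.is_cuda
            gpu_ids.add(t.device.index)

    if not on_gpu:
        raise RuntimeError("Some input tensors are not on a CUDA device.")
    if len(gpu_ids) > 1:
        raise RuntimeError(f"Input tensors span multiple GPUs: {sorted(gpu_ids)}")
    return on_gpu
```

Folding the None and paged checks into one branch avoids an extra attribute lookup and a `continue` per tensor, which adds up on hot paths that validate many pointers per call.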
-  for(int i = threadIdx.x; i < 16; i++)
-    quant_map[i] = T(datatype[i]);
+  if (threadIdx.x < 16)
+    quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
+  //for(int i = threadIdx.x; i < 16; i++)
+  //  quant_map[i] = T(__ldg(&datatype[i]));
Not int8, but another small 4bit change that wanted to sneak its way in. @TimDettmers, I'm just looking for a sanity check here that this makes sense.
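For anyone skimming: `quant_map` is a 16-entry codebook that maps 4-bit codes to real values, which is why only the first 16 threads need to populate it (and why routing the loads through the read-only cache with `__ldg` is cheap). A hedged Python sketch of the lookup idea, with illustrative values rather than the real 4bit table:

```python
import torch

# A 16-entry codebook for 4-bit codes (values illustrative only).
quant_map = torch.linspace(-1.0, 1.0, 16)

# Unpacked 4-bit codes, one per element for simplicity.
codes = torch.randint(0, 16, (8,))

# Dequantization is a table lookup; in the kernel, the 16 entries are staged
# into fast shared memory once per block, which is what the diff above touches.
dequantized = quant_map[codes]
print(dequantized)
```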
There have now been documentation updates, both to the inline docstrings and to the markdown-format public docs. Additionally, tests related to 8bit now use static shapes, certain benchmarking-related tests have been extracted away, and others have been updated. A more detailed look at the benchmarking data will be provided with the release materials. For now, an overview of inference benchmark results:

(inference benchmark overview table)
Really well done @matthewdouglas! What a thorough and well-executed refactor; it clearly took some serious skill and dedication. Excellent work and results! Imo this is ready to merge now, just the docs deep dive and the small improvements we discussed in Slack, and then 🚀✨
Benchmark details have been added. I've also confirmed that everything works on V100 without needing the separate NO_CUBLASLT build.
This PR is the initial phase of a set of changes aimed at improving the LLM.int8() implementation.
@TimDettmers @Titus-von-Koeller
Primary Purpose

Enhancements

- Remove the separate NO_CUBLASLT build while retaining compatibility for targets below sm_75 (verification in progress).
- New vectorwise int8 quantization function (F.int8_vectorwise_quant); see the sketch after this list.
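To make the new function's job concrete, here is a hedged sketch of vectorwise (row-wise) int8 quantization in general; the exact signature and return convention of the actual F.int8_vectorwise_quant may differ:

```python
import torch

def int8_vectorwise_quant_sketch(A: torch.Tensor):
    """Symmetric row-wise int8 quantization: each row is scaled by its absmax."""
    absmax = A.abs().amax(dim=1, keepdim=True).clamp(min=1e-8).float()
    quantized = torch.round(A.float() / absmax * 127.0).to(torch.int8)
    return quantized, absmax.squeeze(1)

A = torch.randn(4, 8, dtype=torch.float16)
q, absmax = int8_vectorwise_quant_sketch(A)

# Approximate dequantization for verification.
A_hat = q.float() * (absmax.unsqueeze(1) / 127.0)
assert torch.allclose(A.float(), A_hat, atol=absmax.max().item() / 127.0)
```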
Deprecations

The following functions from bitsandbytes are deprecated: …

The following functions from bitsandbytes.functional are deprecated: …

Further testing and benchmarking will follow; at the moment, unit tests are passing.
Next steps

- Further improvement of sparse decomposition performance (deferred to future PRs; background sketch below)
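As background on that item: LLM.int8() decomposes the matmul so that outlier feature columns run in fp16 while everything else runs in int8. A simplified sketch of the decomposition (the threshold, shapes, and the fake-quant int8 path are all illustrative):

```python
import torch

def llm_int8_matmul_sketch(A: torch.Tensor, W: torch.Tensor, threshold: float = 6.0):
    """Compute A @ W.t() with outlier columns of A routed through fp16.

    The int8 path is simulated with a fake-quant for brevity; the real
    implementation quantizes both operands and uses int8 GEMM kernels.
    """
    # Columns whose magnitude exceeds the threshold anywhere are outliers.
    outlier_cols = (A.abs() > threshold).any(dim=0)

    # fp16 path: only the outlier columns.
    out_fp16 = A[:, outlier_cols] @ W[:, outlier_cols].t()

    # int8 path: remaining columns, row-wise fake-quantized.
    A_rest = A.clone()
    A_rest[:, outlier_cols] = 0
    scale = A_rest.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    A_q = torch.round(A_rest / scale * 127.0).clamp(-127, 127)
    out_int8 = (A_q @ W.t()) * scale / 127.0

    return out_int8 + out_fp16

A = torch.randn(2, 16)
W = torch.randn(4, 16)
A[:, 3] = 8.0  # inject an outlier column
out = llm_int8_matmul_sketch(A, W)
assert torch.allclose(out, A @ W.t(), atol=1.0)
```

The sparse decomposition cost comes from extracting the outlier columns and running the extra fp16 GEMM, which is the part flagged above for future optimization.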