-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 24 compressor #167
Add 24 compressor #167
Conversation
2015e71
to
dea129e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very clean.
lgtm after tests
!!
2f69d16
to
fc4b23c
Compare
dea129e
to
68ca6c3
Compare
dd16499
to
7155e61
Compare
68ca6c3
to
6636872
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall code looks simple. I'd like to reformulate the scope, though. Specifically, I'm not following why we are restricting to just 2:4 right now when we could easily expand this to handle all sparsity cases and detect whether it is 2:4 format, some type of structured pruning, and if not any then set as unstructured. cc @dsikka
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
testing?
…ly.py Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
Add: tests for get_nested_weight_mappings Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
e56bf72
to
3a6ccc8
Compare
74d6498
to
8fd469f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Only question I had was related to the changes we're making to the sparse_semi_struc to/from methods. Are we making these changes based on kernel compatibility from what was originally in vllm?
@@ -85,7 +86,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): | |||
device = dense.device | |||
|
|||
meta_dtype = torch.int8 | |||
if dense.dtype == torch.int8: | |||
if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when is meta ever int8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made None
@@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense): | |||
idxs1 = bit2 | (bit3.to(torch.int64) << 1) | |||
|
|||
if dense.dtype != torch.float: | |||
if dense.dtype == torch.float8_e4m3fn: | |||
dense_4 = dense_4.view(torch.int8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this required by the kernel only for fp8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a quirk for only fp8 dtype because certain operation are not implemented for this dtype. So we have this hack to view it as int8, this does not move the data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lots of nits that can land later, lgtm!
src/compressed_tensors/compressors/sparse_compressors/sparse_24.py
Outdated
Show resolved
Hide resolved
|
||
|
||
@pytest.mark.parametrize("dtype", supported_dtypes()) | ||
def test_inverse_property_from_dense_then_to_dense(dtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Id personally create a test for sparse_semi_structured_from_dense_cutlass and sparse_semi_structured_to_dense_cutlass. This combines both important operations, but not a blocker
@pytest.mark.parametrize("dtype", supported_dtypes()) | ||
def test_inverse_property_from_dense_then_to_dense(dtype): | ||
M, K = 1024, 1024 | ||
dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also this outputs booleans in the pattern of [False, F, T, T], this is dense?
9eeede7
to
f80a45e
Compare
This PR introduces the
Sparse24Compressor
, designed for 2:4 sparse models. The implementation is based on #182 and corresponds to Part 3 of the [Design Document](https://www.notion.so/Design-Document-24-Compressor-25ac643aee604c298f2bb12a6c220861?pvs=4).Key Changes
Sparse24Compressor
for handling 2:4 sparsity in models.torch.float8e4m3
dtype.Class Hierarchy
The
Sparse24Compressor
follows the established compressor class hierarchy:File Structure
The
Sparse24Compressor
and associated logic are placed within thesparse_compressors
module:Click to expand Verification Methodology
The `Sparse24Compressor` was tested using a comprehensive script that validates its behavior through the following steps: 1. **Load Model**: An uncompressed model is loaded from the Hugging Face model hub or a local directory. 2. **Compression**: The model is compressed using `ModelCompressor`, and the compressed version is saved. 3. **Decompression**: A new base model is initialized, and the compressed weights are decompressed using `ModelCompressor.decompress`. 4. **Parameter Validation**: Parameters in the decompressed model are verified to match the original uncompressed model. 5. **Inference Check**: The decompressed model is used to generate text, ensuring correctness and functionality.Click to expand the Verification Script
Click to expand the sample output generation from decompressed model
Note: the fp8 test can only run on GPU's with cuda capability > 90
Proof that it passes on the right device: