Add support for resharding and int4 preshuffle kernel #2387
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2387
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures as of commit 799924b with merge base 6243040. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 6525c15 to 9e128c1
Why are these ops needed? Is it for DCP?
test/dtypes/test_fbgemm_fp8.py
Outdated
cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
self.assertTrue(cat_weight1.shape, (512, 128))
self.assertTrue(cat_weight2.shape, (256, 256))
can you also assert equality of bits
sure
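For reference, a minimal sketch of what such a bit-equality assertion could look like, assuming the quantized tensors expose `float8_data` and `scale` attributes as described later in this thread (the helper name and slicing are hypothetical, not the actual test code):

```python
import torch

def assert_bits_equal(test_case, cat_weight, orig_weight, rows):
    # Hypothetical helper: compare the raw quantized bits (and scales) of the
    # first `rows` rows of the concatenated weight against the original weight.
    test_case.assertTrue(
        torch.equal(cat_weight.float8_data[:rows], orig_weight.float8_data)
    )
    test_case.assertTrue(torch.equal(cat_weight.scale[:rows], orig_weight.scale))
```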
torchao/dtypes/fbgemm_fp8_tensor.py
Outdated
data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
float8_data: (batch_size, output_channel, input_channel)
scale: (batch_size, output_channel) (since it's per row quantization)
data_to_scale_dim: {0: 0, 1: 1}
This explanation isn't very helpful / I don't know what this is doing
this is a bit confusing, removed
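For readers, a small illustrative sketch of what the (now removed) `data_to_scale_dim` mapping encoded, based only on the docstring quoted above; the shapes below are made up for illustration:

```python
import torch

# Per-row quantization: float8_data is (batch_size, output_channel, input_channel)
# and scale is (batch_size, output_channel), one scale per row.
batch_size, out_ch, in_ch = 2, 4, 8
float8_data = torch.empty(batch_size, out_ch, in_ch, dtype=torch.float8_e4m3fn)
scale = torch.empty(batch_size, out_ch, dtype=torch.float32)

# data_to_scale_dim = {0: 0, 1: 1}: data dim 0 (batch) maps to scale dim 0,
# data dim 1 (output channel) maps to scale dim 1; data dim 2 (input channel)
# has no entry because it is reduced over by the per-row quantization.
data_to_scale_dim = {0: 0, 1: 1}
```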
)

def _transpose_and_reshape(self):
"""This is added for resharding support, since the resharding logic for the model we are |
Do these next two functions need to be methods or can they be implementations of the actual ops
these should be methods, they're specific to the hack we are doing
assert len(self.shape) == 3, (
    f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
)
dim0, dim1, dim2 = self.shape
I don't understand this
this is specific to the hack: we transpose the weight first and then quantize, so (dim0, dim2, dim1) is the original shape
here we are restoring the tensor to its original shape for resharding
This function needs a better description of what it's trying to do
added some more description, I don't have many details beyond the implementation itself actually, please let me know if it helps
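To make the discussion above concrete, here is a hedged sketch (not the actual torchao implementation) of the shape bookkeeping being described: the weight was transposed before quantization, so a 3D tensor of shape (dim0, dim1, dim2) corresponds to an original weight of shape (dim0, dim2, dim1), and resharding wants something closer to the original layout back. The exact reshape target is an assumption here:

```python
import torch

def transpose_and_reshape_for_resharding(float8_data: torch.Tensor) -> torch.Tensor:
    # Sketch only: the exact reshape used in the PR is not shown in this thread.
    assert float8_data.ndim == 3, (
        f"Only expected to be used when the Tensor is 3D, got {float8_data.ndim}"
    )
    dim0, dim1, dim2 = float8_data.shape
    # The weight was transposed before quantization, so swap the last two dims
    # to recover the (dim0, dim2, dim1) layout the resharding logic expects.
    restored = float8_data.transpose(1, 2).contiguous()
    # Flatten the leading dims into a 2D view for checkpoint-style resharding
    # (assumption; adjust to whatever layout the resharding code actually wants).
    return restored.reshape(dim0 * dim2, dim1)
```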
Force-pushed from 4e562e2 to 23c46f4
Groupwise int4 weight only quantization

Tensor Attributes:
packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
when is a weight a 3D tensor, and why is the batch dimension in here? could you share a specific example?
this is used in MoE weights I think, using the weights for bmm is a concrete example in
ao/test/dtypes/test_fbgemm_fp8.py
Line 116 in e29b9bd
def test_bmm(self):
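A hedged, shape-only sketch of the MoE/bmm case mentioned above (the sizes are made up; the real coverage is in the test_bmm test linked above):

```python
import torch

# Each expert carries its own (N, K) weight, so the weight tensor is 3D: (B, N, K).
num_experts, N, K, tokens = 4, 256, 128, 16
weight = torch.randn(num_experts, N, K)        # per-expert weights
x = torch.randn(num_experts, tokens, K)        # inputs already routed per expert
out = torch.bmm(x, weight.transpose(1, 2))     # (num_experts, tokens, N)
```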
Force-pushed from 050b293 to c01624e
Force-pushed from 39e7ca7 to a39fd37
scale = self.scale
float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
scale = scale.unflatten(0, (num_experts, -1)).squeeze(dim=0)
if self.rowwise_dim == 0:
what is going on here
this is just trying to figure out the rowwise quant dimension after the unflatten op, based on the original rowwise_dim
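A small sketch of the dim bookkeeping being described, using plain tensors rather than the actual subclass code:

```python
import torch

float8_data = torch.empty(8, 16)  # rowwise-quantized along dim 0
num_experts = 1

out = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
# unflatten inserts a leading expert dim, shifting every original dim right by one.
# With num_experts == 1 the squeeze removes it again, so the rowwise dim stays 0;
# with num_experts > 1 the tensor stays 3D and the original rowwise dim 0 becomes dim 1.
assert out.shape == (8, 16)
```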
If the concatenation dimension is the same as rowwise_dim, theoretically we should either
(1) check that scales from all tensors are equal and use the first scale
(2) dequantize and requantize
but for now we just use the first scale directly, which might have a slight implication on accuracy
why can't we do 2?
no dequantize kernels for fbgemm right now I think
the current implementation in llama-stack is also incorrect I think, but it still works somehow: the example output still makes sense and is similar to the original output
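A hedged sketch of option (1) from the docstring above, using the hypothetical `float8_data`/`scale` attributes discussed earlier; option (2) would need a dequantize kernel, which as noted is not available in fbgemm right now:

```python
import torch

def cat_along_rowwise_dim(tensors, dim):
    # Option (1): only reuse the first scale if all scales actually agree;
    # otherwise a dequantize + requantize pass (option (2)) would be needed.
    first_scale = tensors[0].scale
    for t in tensors[1:]:
        assert torch.allclose(t.scale, first_scale), (
            "scales differ; dequantize + requantize would be needed"
        )
    cat_data = torch.cat([t.float8_data for t in tensors], dim=dim)
    return cat_data, first_scale
```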
Force-pushed from 65ee2d3 to dfcda29
Force-pushed from 2852d5f to 83420a1
Summary: added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpoint for resharding.
Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
python test/dtypes/test_int4_groupwise_preshuffle.py
Reviewers:
Subscribers:
Tasks:
Tags:
Force-pushed from 83420a1 to 799924b
Summary:
This PR contains multiple changes related to fbgemm kernel integration, sorry for the large PR.
Next:
Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
python test/dtypes/test_int4_groupwise_preshuffle.py
Reviewers:
Subscribers:
Tasks:
Tags: