Skip to content

Add support for resharding and int4 preshuffle kernel #2387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented Jun 16, 2025

Summary:

This PR contains multiple changes related to fbgemm kernel integration, sorry for the large PR.

  • added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. for existing int4 and fp8 fbgemm support. In the future we should probably provide examples for using distributed checkpoint for resharding
  • Also added int4 preshuffle kernel support
  • Also added float8 per row scale_mm code path to FbgemmFp8Tensor
  • removed transpose_input from fbgemm config

Next:

  • rename FbgemmInt4Tensor and FbgemmFp8Tensor to remove fbgemm in the name
  • Deprecate FbgemmConfig

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
python test/dtypes/test_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

Copy link

pytorch-bot bot commented Jun 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2387

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 799924b with merge base 6243040 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 16, 2025
@jerryzh168 jerryzh168 requested a review from drisspg June 16, 2025 20:44
@jerryzh168 jerryzh168 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Jun 16, 2025
@jerryzh168 jerryzh168 force-pushed the fbgemm-reshard branch 2 times, most recently from 6525c15 to 9e128c1 Compare June 16, 2025 20:49
@drisspg
Copy link
Contributor

drisspg commented Jun 16, 2025

Why are these ops needed? is it for DCP?

cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
self.assertTrue(cat_weight1.shape, (512, 128))
self.assertTrue(cat_weight2.shape, (256, 256))
Copy link
Contributor

@drisspg drisspg Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also assert equality of bits

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
float8_data: (batch_size, output_channel, input_channel)
scale: (batch_size, output_channel) (since it's per row quantization)
data_to_scale_dim: {0: 0, 1: 1}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This explanation isn't very helpful / I dont know what this is doing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bit confusing, removed

)

def _transpose_and_reshape(self):
"""This is added for resharding support, since the resharding logic for the model we are
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these these next two functions need to be methods or can they be implementations of the actual ops

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these should be methods, it's specific for the hack we are doing

assert len(self.shape) == 3, (
f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
)
dim0, dim1, dim2 = self.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont understand this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is specific to the hack, we'll transpose the weight first and then quantize, so (dim0, dim2, dim1) is the original shape

we are restoring the shape to original shape to resharding here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function needs better description of what its trying to do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some more descriptions, I don't have much details beyond the implementation itself actually, please let me know if it helps

@jerryzh168 jerryzh168 force-pushed the fbgemm-reshard branch 3 times, most recently from 4e562e2 to 23c46f4 Compare June 17, 2025 01:36
Groupwise int4 weight only quantization

Tensor Attributes:
packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when is a weight a 3D tensor, and why is the batch dimension in here? could you share a specific example

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is used in MoE weights I think, using the weights for bmm is a concrete example in

def test_bmm(self):

@jerryzh168 jerryzh168 force-pushed the fbgemm-reshard branch 2 times, most recently from 050b293 to c01624e Compare June 17, 2025 23:21
@jerryzh168 jerryzh168 requested review from vkuzo and drisspg June 17, 2025 23:23
vkuzo
vkuzo previously requested changes Jun 17, 2025
@jerryzh168 jerryzh168 force-pushed the fbgemm-reshard branch 3 times, most recently from 39e7ca7 to a39fd37 Compare June 18, 2025 03:40
scale = self.scale
float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
scale = scale.unflatten(0, (num_experts, -1)).squeeze(dim=0)
if self.rowwise_dim == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is going on here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just trying to figure out the rowwise quant dimension after unflatten op based on the original rowwise_dim

If the concatention dimension is the same as rowwise_dim, theoretically we should either
(1) check that scales from all tensors are equal and use the first scale
(2) dequantize and requantize
but for now we just use the first scale directly, which might have slight implication on accuaracy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't we do 2?

Copy link
Contributor Author

@jerryzh168 jerryzh168 Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no dequantize kernels for fbgemm right now I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current implementation in llama-stack is also incorrect I think, but it still works somehow, the example output still makes sense and is similar to original output

@jerryzh168 jerryzh168 requested a review from vkuzo June 18, 2025 22:45
@jerryzh168 jerryzh168 force-pushed the fbgemm-reshard branch 3 times, most recently from 65ee2d3 to dfcda29 Compare June 19, 2025 04:00
@jerryzh168 jerryzh168 changed the title Add support for resharding for fbgemm configs Add support for resharding and int4 preshuffle kernel Jun 19, 2025
@jerryzh168 jerryzh168 force-pushed the fbgemm-reshard branch 2 times, most recently from 2852d5f to 83420a1 Compare June 19, 2025 04:19
Summary:
added transpose and cat op support, and also some custom transpose/reshape/unflatten support
for resharding.

In the future we should probably provide examples for using distributed checkpoint for resharding

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
python test/dtypes/test_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants