Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dequant sycl Kernel #1300

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open

Dequant sycl Kernel #1300

wants to merge 15 commits into from

Conversation

sunjiweiswift
Copy link
Contributor

@sunjiweiswift sunjiweiswift commented Jan 20, 2025

Use an independent dequant kernel with onednn matmul to complete the calculation of the first token
You can modify the dequant kernel to support more WOQs

Copy link
Contributor

@airMeng airMeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will suggest to reuse the code between dequantization and dequantized GEMM as much as possible.

@@ -232,7 +232,7 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):


@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should work on Windows too, right? Can you validate?

src/ATen/native/xpu/sycl/Dequant_int4.cpp Show resolved Hide resolved
@sunjiweiswift sunjiweiswift changed the title Jiwei/dequant op Dequant sycl Kernel of Int4 Jan 21, 2025
@sunjiweiswift sunjiweiswift changed the title Dequant sycl Kernel of Int4 Dequant sycl Kernel Jan 22, 2025
src/ATen/native/xpu/LinearInt4.cpp Show resolved Hide resolved
src/ATen/native/xpu/sycl/Dequant_int4.cpp Show resolved Hide resolved
int n,
int k,
const uint8_t* weight_int4,
const scalar_t* ScaleAndZeros,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to unify the coding style. Although, the coding style of torch-xpu-ops is a mess. However, we are working on it and will enable the linter ASAP. Therefore, it would be nice if you could keep the coding style consistency. @xytintel , @fengyuan14 FYI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide a reference cpp?

src/ATen/native/xpu/LinearInt4.cpp Show resolved Hide resolved

float tmp[TileN];
bool high4 = sg_id % 2 != 0;
for (int in = 0; in < TileN; in++) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do an increamental of 2 instead of 1 to remove this high4 check?

for (int in = 0; in < TileN; in += 2) {
  low4 = tmp[in + 0];
  high4 = tmp[in + 1];

and also since `TileN` is constexpr, it is possible to `unroll` it?
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I looked at it carefully. Because it is along the N direction, if +2, the code will be more complicated

static_assert(TileK == 1);
int k = weight.size(0);
int n = weight.size(1);
int nsg_k = k / GroupK;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we check before doing the div for integer here:

TORCH_CHECK(k % GroupK == 0 && n % GroupN == 0);

float tmp[TileN];
bool high4 = sg_id % 2 != 0;
for (int in = 0; in < TileN; in++) {
int scale_offset =

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be aware of the indexing here: integer div might be slow, i am not sure that whether the compiler will do the optimization here or not. But a more promising way it to move k / block_size and sg_id & TileK / block_size out of the loop.

: static_cast<int8_t>((srcu8 & 0x0f) - 8) * scale + zero_point;
}

float tmpT[TileN];

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does sycl has __shared__?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sycl::local_accessor yes we shall update these

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants