Dequant sycl Kernel #1300
base: main
82c2d23 to f2168eb
I suggest reusing as much code as possible between the dequantization kernel and the dequantized GEMM.
@@ -232,7 +232,7 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):

@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
It should work on Windows too, right? Can you validate?
int n,
int k,
const uint8_t* weight_int4,
const scalar_t* ScaleAndZeros,
It would be better to unify the coding style. The coding style of torch-xpu-ops is admittedly a mess, but we are working on it and will enable the linter ASAP. Until then, it would be nice if you could keep the coding style consistent. @xytintel, @fengyuan14 FYI.
Can you provide a reference C++ implementation?
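A plain C++ reference for the int4 dequantization could look like the sketch below. It reuses the per-element formula from the kernel, `(nibble - 8) * scale + zero_point`. The packing order (two values per byte along N, low nibble first), the row-major `[k, n]` layout, and the separate per-group scale and zero-point arrays are assumptions for illustration and may not match the layout this PR uses. The name `dequant_int4_ref` is hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference; the layout below is an assumption.
// weight_int4: k rows, n / 2 bytes per row; byte j of a row holds
//              column 2j in its low nibble and column 2j + 1 in its high nibble.
// scales/zeros: row-major [k / group_k, n], one value per K-group and column.
std::vector<float> dequant_int4_ref(
    const std::vector<uint8_t>& weight_int4,
    const std::vector<float>& scales,
    const std::vector<float>& zeros,
    int k,
    int n,
    int group_k) {
  std::vector<float> out(static_cast<size_t>(k) * n);
  for (int ik = 0; ik < k; ++ik) {
    const int g = ik / group_k; // quantization group along K
    for (int in = 0; in < n; ++in) {
      const uint8_t byte =
          weight_int4[static_cast<size_t>(ik) * (n / 2) + in / 2];
      const int nibble = (in % 2 == 0) ? (byte & 0x0f) : (byte >> 4);
      const float scale = scales[static_cast<size_t>(g) * n + in];
      const float zero = zeros[static_cast<size_t>(g) * n + in];
      // Same formula as the kernel: shift to the signed range, scale, add zero point.
      out[static_cast<size_t>(ik) * n + in] =
          static_cast<float>(nibble - 8) * scale + zero;
    }
  }
  return out;
}
```

A unit test could compare the kernel output against such a reference on small shapes, provided the reference is adapted to the real packing.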
float tmp[TileN];
bool high4 = sg_id % 2 != 0;
for (int in = 0; in < TileN; in++) {
Can we increment by 2 instead of 1 to remove this high4 check?
for (int in = 0; in < TileN; in += 2) {
  low4 = tmp[in + 0];
  high4 = tmp[in + 1];
}
Also, since `TileN` is constexpr, is it possible to `unroll` the loop?
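The step-by-2 idea can be sketched in plain C++ as follows. This is an illustration only: the helper name `unpack_int4_pairs` is hypothetical, and it assumes that both nibbles of a byte belong to consecutive elements along the loop direction (low nibble first), which may not hold for the kernel's actual layout.

```cpp
#include <cstdint>

// Hypothetical helper: unpack TileN int4 values from TileN / 2 packed bytes.
// Assumes the low nibble holds the even element and the high nibble the odd one.
template <int TileN>
void unpack_int4_pairs(const uint8_t* packed, int8_t* out) {
  static_assert(TileN % 2 == 0, "TileN must be even to step by 2");
  // TileN is a compile-time constant, so the compiler is free to unroll
  // this loop fully; no per-iteration high/low branch is needed.
  for (int in = 0; in < TileN; in += 2) {
    const uint8_t byte = packed[in / 2];
    out[in + 0] = static_cast<int8_t>((byte & 0x0f) - 8); // low nibble
    out[in + 1] = static_cast<int8_t>((byte >> 4) - 8); // high nibble
  }
}
```

Each iteration consumes one byte and produces two outputs, so the nibble selection is fixed per statement rather than decided by a flag.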
sure
Sorry, I looked at it more carefully. Because the loop runs along the N direction, stepping by 2 would make the code more complicated.
static_assert(TileK == 1);
int k = weight.size(0);
int n = weight.size(1);
int nsg_k = k / GroupK;
Shall we check divisibility before doing the integer division here?
TORCH_CHECK(k % GroupK == 0 && n % GroupN == 0);
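The reason for the check is that integer division truncates: a `k` that is not a multiple of `GroupK` would silently drop the tail rows instead of failing. A framework-free sketch of the same guard is shown below; the helper name `checked_group_count` is hypothetical, and in the PR itself `TORCH_CHECK` would play this role.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper: return size / group, but refuse inputs that
// integer division would silently truncate.
int checked_group_count(int size, int group) {
  if (group <= 0 || size % group != 0) {
    throw std::invalid_argument(
        "size " + std::to_string(size) + " is not a multiple of group " +
        std::to_string(group));
  }
  return size / group;
}
```

For example, with `k = 100` and `GroupK = 32`, plain division yields 3 groups and covers only 96 rows, while the guarded version reports the mismatch.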
float tmp[TileN];
bool high4 = sg_id % 2 != 0;
for (int in = 0; in < TileN; in++) {
int scale_offset =
Be aware of the indexing here: integer division can be slow, and I am not sure whether the compiler will optimize it. A more reliable approach is to move `k / block_size` and `sg_id & TileK / block_size` out of the loop.
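The hoisting suggestion can be illustrated with a simplified index computation. The function names, the parameters `n_start` and `tile_n`, and the row-major `[k / block_size, n]` scale layout are assumptions made for this sketch; the kernel's real offset expression is not reproduced here.

```cpp
#include <vector>

// Naive form: the loop-invariant division is written inside the loop,
// so it is recomputed every iteration unless the compiler hoists it.
std::vector<int> scale_offsets_naive(
    int k, int n_start, int tile_n, int n, int block_size) {
  std::vector<int> offsets(tile_n);
  for (int in = 0; in < tile_n; ++in) {
    offsets[in] = (k / block_size) * n + n_start + in;
  }
  return offsets;
}

// Hoisted form: the division and the row base are computed once,
// leaving only an addition per iteration.
std::vector<int> scale_offsets_hoisted(
    int k, int n_start, int tile_n, int n, int block_size) {
  std::vector<int> offsets(tile_n);
  const int base = (k / block_size) * n + n_start;
  for (int in = 0; in < tile_n; ++in) {
    offsets[in] = base + in;
  }
  return offsets;
}
```

Both forms produce identical offsets; the second simply does not rely on the compiler to recognize the invariant.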
: static_cast<int8_t>((srcu8 & 0x0f) - 8) * scale + zero_point;
}

float tmpT[TileN];
Does SYCL have an equivalent of `__shared__`?
sycl::local_accessor
Yes, we should update these.
Use an independent dequant kernel with oneDNN matmul to compute the first token.
The dequant kernel can be modified to support more WOQ (weight-only quantization) schemes.