FP8 e4m3_fnuz support for ROCm #2588
Conversation
            self.input_scale,
            self.activation_scale_ub,
            bias,
            self.dtype,
        )


class Fp8Linear(torch.nn.Module):
Would it be cleaner to have a separate Fp8LinearRocm?
Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.
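One way to keep that conditional code out of the hot path, sketched below. This is an illustration of the dispatch pattern being discussed, not the PR's actual code: `SYSTEM`, `Fp8Linear`, and a hypothetical `Fp8LinearRocm` stand in for the real classes, and the platform check is stubbed with a string.

```python
# Hypothetical sketch: select the linear implementation once, up front,
# instead of branching inside forward(). Class names and the SYSTEM flag
# mirror the discussion; the split itself is an assumption for illustration.
SYSTEM = "rocm"  # in the real code this would be detected at import time


class Fp8Linear:
    """CUDA path: e4m3fn weights fed to torch._scaled_mm."""


class Fp8LinearRocm(Fp8Linear):
    """ROCm path: weights normalized to e4m3fnuz first."""


def get_fp8_linear(system: str) -> type:
    # One conditional at construction time keeps forward() branch-free.
    return Fp8LinearRocm if system == "rocm" else Fp8Linear
```

The trade-off is the usual one: a subclass removes per-call branches but duplicates any shared logic that later diverges, which is presumably why the FP8 Marlin split mentioned above was judged worth it.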
@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
                .reshape(-1)
                .expand(w.shape[0])
            )
        try:
            input_scale = weights.get_tensor(
`Weights` also has `_has_tensor` — maybe we should make it public and use it here?
Same for the `try: [...] get_tensor` below.
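The suggestion can be sketched as follows. The `Weights` class here is a minimal stand-in (the real one loads safetensors files), and `has_tensor`/`load_input_scale` are hypothetical names for the public method and its call site:

```python
# Sketch of the review suggestion: expose the private _has_tensor helper as
# a public has_tensor() and query it instead of wrapping get_tensor() in
# try/except. Minimal stand-in class; names are assumptions.
class Weights:
    def __init__(self, tensors: dict):
        self._tensors = tensors

    def has_tensor(self, name: str) -> bool:
        # Formerly the private _has_tensor helper.
        return name in self._tensors

    def get_tensor(self, name: str):
        return self._tensors[name]


def load_input_scale(weights: Weights, prefix: str):
    # Replaces: try: input_scale = weights.get_tensor(...) except ...
    name = f"{prefix}.input_scale"
    return weights.get_tensor(name) if weights.has_tensor(name) else None
```

This also avoids accidentally swallowing unrelated errors raised inside `get_tensor`, which a broad `except` around it can do.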
@@ -72,6 +99,10 @@ def fp8_quantize(
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
We should wire up `scale` at some point for CUDA as well.
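For context on what the normalization has to do, here is a bit-level sketch. It is not the PR's implementation (which operates on torch tensors); it only illustrates the two facts the conversion rests on: e4m3_fnuz uses an exponent bias of 8 versus e4m3fn's 7, so an unchanged bit pattern decodes to half the value (compensated by doubling the scale), and the pattern `0x80` (negative zero in e4m3fn) encodes NaN in e4m3_fnuz and must be flushed to zero. The function name mirrors the one in the diff; the `bytes`-based signature is an assumption for illustration.

```python
# Hedged sketch of e4m3fn -> e4m3fnuz normalization at the bit level,
# written without torch so it runs anywhere. Real code would reinterpret
# a uint8 view of the FP8 tensor instead of a bytes object.
def normalize_e4m3fn_to_e4m3fnuz(weight_bytes: bytes, scale: float):
    # Flush -0 (0x80) to +0: that bit pattern means NaN in e4m3fnuz.
    fixed = bytes(0 if b == 0x80 else b for b in weight_bytes)
    # The fnuz exponent bias is one higher, so the same bits now decode to
    # half the magnitude; double the dequantization scale to compensate.
    return fixed, scale * 2.0
```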
            bias=self.bias,
        )

        if type(output) is tuple and len(output) == 2:
Did this change between torch versions or is output for AMD different?
-        if type(output) is tuple and len(output) == 2:
+        if isinstance(output, tuple) and len(output) == 2:
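On the question above: older torch releases reportedly returned a `(output, amax)` pair from `torch._scaled_mm`, while newer ones return the output tensor alone, so a compatibility shim along these lines covers both. The helper name is hypothetical; the check itself matches the suggested diff:

```python
# Version-compatibility shim: accept either the bare output tensor or the
# (output, amax) pair that older torch._scaled_mm versions returned.
# isinstance() is preferred over an exact type() check because it also
# accepts tuple subclasses such as namedtuples.
def unwrap_scaled_mm_output(output):
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]  # drop the auxiliary amax value
    return output
```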
@@ -62,7 +62,7 @@ def from_unquant(cls, weight, bias, dtype):
        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

    @classmethod
-   def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+   def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
Type.
These arguments get a bit messy: it's easy to mix up a tensor and a float (which was already happening here?). Maybe we should switch these to kwargs-only so that the call sites need to be explicit (+ type annotations).
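One way to realize the kwargs-only suggestion, sketched below. The class name, attribute layout, and annotations are illustrative assumptions; only the bare `*` marker, which forces every argument to be passed by keyword, is the point:

```python
# Sketch of a keyword-only from_fp8. The lone `*` makes all following
# parameters keyword-only, so a float scale can no longer be silently
# passed where a tensor upper bound is expected.
from typing import Optional


class Fp8Weight:
    @classmethod
    def from_fp8(
        cls,
        *,  # everything after this must be named at the call site
        weight,
        scale,
        input_scale=None,
        scale_upper_bound: Optional[float] = None,
        bias=None,
        dtype=None,
    ):
        obj = cls.__new__(cls)
        obj.weight, obj.scale = weight, scale
        obj.input_scale = input_scale
        obj.scale_upper_bound = scale_upper_bound
        obj.bias, obj.dtype = bias, dtype
        return obj
```

With this signature, `Fp8Weight.from_fp8(w, s, None, ub, b, dt)` raises a `TypeError` instead of quietly binding arguments in the wrong order.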
What does this PR do?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.