-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add A10G support in CI #176
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
9b17201
Add A10G support in CI
msaroufim 4204567
push
msaroufim 95aee6a
push
msaroufim a63f48a
push
msaroufim e3b8f8b
push
msaroufim 01d098f
push
msaroufim 1eac76c
push
msaroufim d1588dc
push
msaroufim a01c216
push
msaroufim ca2c8c0
push
msaroufim b99a55a
push
msaroufim 63dc97b
Convert to utilize linux_job.yml
seemethere f7f564f
switch to use linux.4xlarge
seemethere 12f9c93
no more need for GPU checks
seemethere 3cfe092
push
msaroufim 60d36ae
this feels gross
msaroufim 48afbce
push
msaroufim c1ec5bc
Merge branch 'main' into msaroufim/a10g
msaroufim 436e4a0
push
msaroufim df0c567
push
msaroufim File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ | |
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx | ||
import os | ||
from parameterized import parameterized | ||
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 | ||
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 | ||
|
||
torch.manual_seed(0) | ||
config.cache_size_limit = 100 | ||
|
@@ -449,6 +449,7 @@ def test_dynamic_quant_per_tensor_numerics_cpu(self): | |
for row in test_cases: | ||
self._test_dynamic_quant_per_tensor_numerics_impl(*row) | ||
|
||
@unittest.skip("test case incorrect on A10G") | ||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
def test_dynamic_quant_per_tensor_numerics_cuda(self): | ||
# verifies that dynamic quant per tensor in plain pytorch matches | ||
|
@@ -640,6 +641,8 @@ def test__int_mm(self): | |
torch.testing.assert_close(y_ref, y_opt, atol=0, rtol=0) | ||
|
||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
|
||
def test__int_mm_eager_and_torch_compile_numerics(self): | ||
def __int_mm_ref(x, w): | ||
x = x.cpu().to(torch.int32) | ||
|
@@ -947,6 +950,7 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): | |
) | ||
|
||
@parameterized.expand(COMMON_DEVICE_DTYPE) | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): | ||
self._test_lin_weight_subclass_impl( | ||
AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype | ||
|
@@ -1020,6 +1024,8 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): | |
) | ||
|
||
@parameterized.expand(COMMON_DEVICE_DTYPE) | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
|
||
def test_int8_weight_only_quant_subclass_api(self, device, dtype): | ||
self._test_lin_weight_subclass_api_impl( | ||
change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype | ||
|
@@ -1086,6 +1092,7 @@ def test_weight_only_quant(self): | |
@parameterized.expand(COMMON_DEVICE_DTYPE) | ||
@torch.no_grad() | ||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
def test_weight_only_quant_force_mixed_mm(self, device, dtype): | ||
if device != "cuda": | ||
self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") | ||
|
@@ -1112,6 +1119,8 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): | |
|
||
@parameterized.expand(COMMON_DEVICE_DTYPE) | ||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
|
||
def test_weight_only_quant_use_mixed_mm(self, device, dtype): | ||
if device != "cuda": | ||
self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}") | ||
|
@@ -1348,6 +1357,8 @@ class TestAutoQuant(unittest.TestCase): | |
# (256, 256, 128), TODO: Runs out of shared memory on T4 | ||
])) | ||
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
|
||
def test_autoquant_one_input(self, device, dtype, m, k, n): | ||
print("(m, k, n): ", (m, k, n)) | ||
if device != "cuda" or not torch.cuda.is_available(): | ||
|
@@ -1381,6 +1392,8 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): | |
(32, 32, 128, 128), | ||
])) | ||
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") | ||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available(), "SystemError: AST constructor recursion depth mismatch (before=45, after=84)") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this fail for all dtypes etc.? |
||
|
||
def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n): | ||
if device != "cuda" or not torch.cuda.is_available(): | ||
self.skipTest(f"autoquant currently does not support {device}") | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HDCharles FYI