add require_triton and enable test_dynamo to work on xpu (#2878)
faaany authored Jul 3, 2024
1 parent fec1170 commit 3a02754
Showing 4 changed files with 19 additions and 5 deletions.
src/accelerate/test_utils/testing.py (9 additions, 1 deletion)

@@ -53,6 +53,7 @@
     is_torch_xla_available,
     is_torchvision_available,
     is_transformers_available,
+    is_triton_available,
     is_wandb_available,
     is_xpu_available,
     str_to_bool,
@@ -213,7 +214,7 @@ def require_transformers(test_case):
 
 def require_timm(test_case):
     """
-    Decorator marking a test that requires transformers. These tests are skipped when they are not.
+    Decorator marking a test that requires timm. These tests are skipped when they are not.
     """
     return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case)
 
@@ -225,6 +226,13 @@ def require_torchvision(test_case):
     return unittest.skipUnless(is_torchvision_available(), "test requires the torchvision library")(test_case)
 
 
+def require_triton(test_case):
+    """
+    Decorator marking a test that requires triton. These tests are skipped when they are not.
+    """
+    return unittest.skipUnless(is_triton_available(), "test requires the triton library")(test_case)
+
+
 def require_schedulefree(test_case):
     """
     Decorator marking a test that requires schedulefree. These tests are skipped when they are not.
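For reference, a hypothetical test using the new decorator might look like the sketch below (the class and test names are illustrative, not part of this commit); require_triton simply wraps unittest.skipUnless, so the test is reported as skipped whenever the triton package cannot be imported.

import unittest

from accelerate.test_utils.testing import require_triton


class TritonTests(unittest.TestCase):
    @require_triton
    def test_runs_only_with_triton(self):
        # Executed only when the triton package is importable;
        # otherwise unittest marks the test as skipped.
        import triton  # noqa: F401
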
src/accelerate/utils/__init__.py (1 addition, 0 deletions)

@@ -109,6 +109,7 @@
     is_torchvision_available,
     is_transformer_engine_available,
     is_transformers_available,
+    is_triton_available,
     is_wandb_available,
     is_xpu_available,
 )
src/accelerate/utils/imports.py (4 additions, 0 deletions)

@@ -248,6 +248,10 @@ def is_timm_available():
     return _is_package_available("timm")
 
 
+def is_triton_available():
+    return _is_package_available("triton")
+
+
 def is_aim_available():
     package_exists = _is_package_available("aim")
     if package_exists:
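The new helper mirrors the existing is_*_available checks by delegating to _is_package_available. A minimal, hypothetical use of it as an import guard (the fallback choice below is illustrative, not part of this commit):

from accelerate.utils import is_triton_available

if is_triton_available():
    import triton  # noqa: F401  # safe: the package is importable

    compile_backend = "inductor"  # triton can generate the fused GPU kernels inductor needs
else:
    compile_backend = "eager"  # hypothetical fallback when triton is missing
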
tests/test_utils.py (5 additions, 4 deletions)

@@ -27,12 +27,12 @@
 
 from accelerate.state import PartialState
 from accelerate.test_utils.testing import (
-    require_cuda,
     require_huggingface_suite,
     require_non_cpu,
     require_non_torch_xla,
     require_torch_min_version,
     require_tpu,
+    require_triton,
     torch_device,
 )
 from accelerate.test_utils.training import RegressionModel
@@ -190,15 +190,16 @@ def test_can_undo_fp16_conversion(self):
         model = extract_model_from_parallel(model, keep_fp32_wrapper=False)
         _ = pickle.dumps(model)
 
-    @require_cuda
+    @require_triton
+    @require_non_cpu
     @require_torch_min_version(version="2.0")
     def test_dynamo(self):
        model = RegressionModel()
        model._original_forward = model.forward
-        model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
+        model.forward = torch.autocast(device_type=torch_device, dtype=torch.float16)(model.forward)
        model.forward = convert_outputs_to_fp32(model.forward)
        model.forward = torch.compile(model.forward, backend="inductor")
-        inputs = torch.randn(4, 10).cuda()
+        inputs = torch.randn(4, 10).to(torch_device)
        _ = model(inputs)
 
     def test_extract_model(self):
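Replacing torch.cuda.amp.autocast with torch.autocast(device_type=torch_device, ...) and .cuda() with .to(torch_device) is what makes the test device-agnostic, so it can run on XPU as well as CUDA. A minimal sketch of the same pattern outside the test suite (the device-selection logic below is a hypothetical stand-in for accelerate's torch_device helper):

import torch

# Hypothetical device pick: prefer CUDA, then XPU, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
else:
    device = "cpu"

# float16 is the usual GPU/XPU autocast dtype; CPU autocast prefers bfloat16.
dtype = torch.float16 if device != "cpu" else torch.bfloat16

model = torch.nn.Linear(10, 1).to(device)

# torch.autocast(device_type=...) works for cuda, xpu, and cpu alike,
# unlike the CUDA-only torch.cuda.amp.autocast.
forward = torch.autocast(device_type=device, dtype=dtype)(model.forward)

inputs = torch.randn(4, 10, device=device)
_ = forward(inputs)
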
