diff --git a/tests/benchmarks/test_work_stealing.py b/tests/benchmarks/test_work_stealing.py index dcb9c6bc05..ae9fc47b66 100644 --- a/tests/benchmarks/test_work_stealing.py +++ b/tests/benchmarks/test_work_stealing.py @@ -1,11 +1,13 @@ import time import dask.array as da +import distributed import numpy as np import pytest from coiled import Cluster from dask import delayed, utils from distributed import Client +from packaging.version import Version from tornado.ioloop import PeriodicCallback @@ -16,11 +18,10 @@ def test_trivial_workload_should_not_cause_work_stealing(small_client): small_client.gather(futs) -# @pytest.mark.xfail( -# distributed.__version__ == "2022.6.0", -# reason="https://github.com/dask/distributed/issues/6624", -# ) -@pytest.mark.skip("https://github.com/coiled/coiled-runtime/issues/336") +@pytest.mark.xfail( + Version(distributed.__version__) < Version("2022.6.1"), + reason="https://github.com/dask/distributed/issues/6624", +) def test_work_stealing_on_scaling_up( test_name_uuid, upload_cluster_dump, benchmark_all ): @@ -30,6 +31,7 @@ def test_work_stealing_on_scaling_up( worker_vm_types=["t3.medium"], scheduler_vm_types=["t3.xlarge"], wait_for_workers=True, + package_sync=True, ) as cluster: with Client(cluster) as client: with upload_cluster_dump(client, cluster), benchmark_all(client):