Skip to content

Commit

Permalink
Merge pull request #191 from jakevdp:no-multithread
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674338367
  • Loading branch information
The ml_dtypes Authors committed Sep 13, 2024
2 parents b39b73c + abe1d89 commit 7fab9cd
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 29 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ jobs:
- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
# TODO(jakevdp): re-add 313t & free-threading support
CIBW_ARCHS_LINUX: auto aarch64
CIBW_ARCHS_MACOS: universal2
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313*
CIBW_FREE_THREADED_SUPPORT: True
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-* # cp313t-*
# CIBW_FREE_THREADED_SUPPORT: True
CIBW_PRERELEASE_PYTHONS: True
CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*"
CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist
Expand Down
38 changes: 20 additions & 18 deletions ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
# from multi_thread_utils import multi_threaded
import numpy as np

bfloat16 = ml_dtypes.bfloat16
Expand Down Expand Up @@ -221,11 +221,12 @@ def dtype_is_signed(dtype):
}


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# pylint: disable=g-complex-comprehension
@multi_threaded(
num_workers=3,
skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"],
)
# @multi_threaded(
# num_workers=3,
# skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"],
# )
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down Expand Up @@ -660,20 +661,21 @@ def testDtypeFromString(self, float_type):
]


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# pylint: disable=g-complex-comprehension
@multi_threaded(
num_workers=3,
skip_tests=[
"testBinaryUfunc",
"testConformNumpyComplex",
"testFloordivCornerCases",
"testDivmodCornerCases",
"testSpacing",
"testUnaryUfunc",
"testCasts",
"testLdexp",
],
)
# @multi_threaded(
# num_workers=3,
# skip_tests=[
# "testBinaryUfunc",
# "testConformNumpyComplex",
# "testFloordivCornerCases",
# "testDivmodCornerCases",
# "testSpacing",
# "testUnaryUfunc",
# "testCasts",
# "testLdexp",
# ],
# )
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down
5 changes: 3 additions & 2 deletions ml_dtypes/tests/finfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
# from multi_thread_utils import multi_threaded
import numpy as np

ALL_DTYPES = [
Expand Down Expand Up @@ -55,7 +55,8 @@
}


@multi_threaded(num_workers=3)
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
class FinfoTest(parameterized.TestCase):

def assertNanEqual(self, x, y):
Expand Down
5 changes: 3 additions & 2 deletions ml_dtypes/tests/iinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
# from multi_thread_utils import multi_threaded
import numpy as np


@multi_threaded(num_workers=3)
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
class IinfoTest(parameterized.TestCase):

def testIinfoInt2(self):
Expand Down
8 changes: 5 additions & 3 deletions ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
from multi_thread_utils import multi_threaded
# from multi_thread_utils import multi_threaded
import numpy as np

int2 = ml_dtypes.int2
Expand All @@ -48,8 +48,9 @@ def ignore_warning(**kw):
yield


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# Tests for the Python scalar type
@multi_threaded(num_workers=3)
# @multi_threaded(num_workers=3)
class ScalarTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down Expand Up @@ -246,8 +247,9 @@ def testCanCast(self, a, b):
)


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# Tests for the Python scalar type
@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
# @multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
class ArrayTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down
5 changes: 3 additions & 2 deletions ml_dtypes/tests/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from absl.testing import absltest
import ml_dtypes
from multi_thread_utils import multi_threaded
# from multi_thread_utils import multi_threaded


@multi_threaded(num_workers=3)
# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
class CustomFloatTest(absltest.TestCase):

def test_version_matches_package_metadata(self):
Expand Down

0 comments on commit 7fab9cd

Please sign in to comment.