Skip to content

Commit

Permalink
test: fix new cublas test
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-murray committed Dec 4, 2023
1 parent 2fe40e0 commit eab2dff
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/matvis/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pyuvdata import UVBeam
from typing import Callable, Optional

from . import _cublas as cb
from . import conversions
from ._utils import ceildiv
from ._uvbeam_to_raw import uvbeam_to_azza_grid
Expand Down Expand Up @@ -46,6 +45,8 @@
cublasZgemm,
)

from . import _cublas as cb

HAVE_CUDA = True

except ImportError:
Expand Down
14 changes: 5 additions & 9 deletions tests/test_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

import pytest

pytest.importorskip("pycuda")

import numpy as np
from pycuda.gpuarray import to_gpu

from matvis import _cublas as cb


@pytest.mark.parameterize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128])
def test_dotc(dtype):
"""Test the dotc function."""
a = np.random.randn(10).astype(dtype)
Expand All @@ -24,9 +24,7 @@ def test_dotc(dtype):
assert np.allclose(c, np.vdot(a, b))


@pytest.mark.parameterize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128])
def test_gemm(dtype):
"""Test the gemm function."""
a = np.random.randn(10, 12).astype(dtype)
Expand All @@ -39,9 +37,7 @@ def test_gemm(dtype):
assert np.allclose(c, np.dot(a, b))


@pytest.mark.parameterize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128])
def test_zz(dtype):
"""Test the zz function."""
a = np.random.randn(10, 15).astype(dtype)
Expand Down

0 comments on commit eab2dff

Please sign in to comment.