Skip to content

Commit

Permalink
Adding tests on different data types
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed May 12, 2021
1 parent 3c054d8 commit 0eb98c6
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions python/cuml/test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,13 +1373,13 @@ def test_hinge_loss(nrows, ncols, n_info, input_type, n_classes):
stress_param(500000000)
])
@pytest.mark.parametrize("input_type", ["cudf", "cupy"])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
def test_kl_divergence(nfeatures, input_type, dtype):
@pytest.mark.parametrize("dtypeP", [cp.float32, cp.float64])
@pytest.mark.parametrize("dtypeQ", [cp.float32, cp.float64])
def test_kl_divergence(nfeatures, input_type, dtypeP, dtypeQ):
if not has_scipy():
pytest.skip('Skipping test_entropy_random because Scipy is missing')
pytest.skip('Skipping test_kl_divergence because Scipy is missing')

from scipy.stats import entropy as sp_entropy
# Test larger sizes to sklearn
rng = np.random.RandomState(5)

P = rng.random_sample((nfeatures))
Expand All @@ -1390,11 +1390,17 @@ def test_kl_divergence(nfeatures, input_type, dtype):
sk_res = sp_entropy(P, Q)

if input_type == "cudf":
P = cudf.DataFrame(P, dtype=dtype)
Q = cudf.DataFrame(Q, dtype=dtype)
P = cudf.DataFrame(P, dtype=dtypeP)
Q = cudf.DataFrame(Q, dtype=dtypeQ)
elif input_type == "cupy":
P = cp.asarray(P, dtype=dtype)
Q = cp.asarray(Q, dtype=dtype)
cu_res = cu_kl_divergence(P, Q)
P = cp.asarray(P, dtype=dtypeP)
Q = cp.asarray(Q, dtype=dtypeQ)

if dtypeP != dtypeQ:
with pytest.raises(TypeError):
cu_kl_divergence(P, Q, convert_dtype=False)
cu_res = cu_kl_divergence(P, Q)
else:
cu_res = cu_kl_divergence(P, Q, convert_dtype=False)

cp.testing.assert_array_almost_equal(cu_res, sk_res)

0 comments on commit 0eb98c6

Please sign in to comment.