diff --git a/src/csrc/dtype.c b/src/csrc/dtype.c index 2033f01..414e7d4 100644 --- a/src/csrc/dtype.c +++ b/src/csrc/dtype.c @@ -455,6 +455,174 @@ quadprec_compare(void *a, void *b, void *arr) } } +/* + * Argmax function for np.argmax() + * Finds the index of the maximum element. + * NaN values are ignored unless all values are NaN. + */ +static int +quadprec_argmax(char *data, npy_intp n, npy_intp *max_ind, void *arr) +{ + PyArrayObject *array = (PyArrayObject *)arr; + QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(array); + npy_intp elsize = descr->base.elsize; + + *max_ind = 0; + + if (descr->backend == BACKEND_SLEEF) { + // Find first non-NaN value as initial max + npy_intp start = 0; + Sleef_quad max_val; + for (start = 0; start < n; start++) { + max_val = *(Sleef_quad *)(data + start * elsize); + if (!Sleef_iunordq1(max_val, max_val)) { + *max_ind = start; + break; + } + } + + // If all values are NaN, return 0 + if (start == n) { + *max_ind = 0; + return 0; + } + + // Find maximum + for (npy_intp i = start + 1; i < n; i++) { + Sleef_quad val = *(Sleef_quad *)(data + i * elsize); + + // Skip NaN values + if (Sleef_iunordq1(val, val)) { + continue; + } + + if (Sleef_icmpgtq1(val, max_val)) { + max_val = val; + *max_ind = i; + } + } + } + else { + // Find first non-NaN value as initial max + npy_intp start = 0; + long double max_val; + for (start = 0; start < n; start++) { + max_val = *(long double *)(data + start * elsize); + if (!isnan(max_val)) { + *max_ind = start; + break; + } + } + + // If all values are NaN, return 0 + if (start == n) { + *max_ind = 0; + return 0; + } + + // Find maximum + for (npy_intp i = start + 1; i < n; i++) { + long double val = *(long double *)(data + i * elsize); + + // Skip NaN values + if (isnan(val)) { + continue; + } + + if (val > max_val) { + max_val = val; + *max_ind = i; + } + } + } + + return 0; +} + +/* + * Argmin function for np.argmin() + * Finds the index of the minimum element. + * NaN values are ignored unless all values are NaN. + */ +static int +quadprec_argmin(char *data, npy_intp n, npy_intp *min_ind, void *arr) +{ + PyArrayObject *array = (PyArrayObject *)arr; + QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(array); + npy_intp elsize = descr->base.elsize; + + *min_ind = 0; + + if (descr->backend == BACKEND_SLEEF) { + // Find first non-NaN value as initial min + npy_intp start = 0; + Sleef_quad min_val; + for (start = 0; start < n; start++) { + min_val = *(Sleef_quad *)(data + start * elsize); + if (!Sleef_iunordq1(min_val, min_val)) { + *min_ind = start; + break; + } + } + + // If all values are NaN, return 0 + if (start == n) { + *min_ind = 0; + return 0; + } + + // Find minimum + for (npy_intp i = start + 1; i < n; i++) { + Sleef_quad val = *(Sleef_quad *)(data + i * elsize); + + // Skip NaN values + if (Sleef_iunordq1(val, val)) { + continue; + } + + if (Sleef_icmpltq1(val, min_val)) { + min_val = val; + *min_ind = i; + } + } + } + else { + // Find first non-NaN value as initial min + npy_intp start = 0; + long double min_val; + for (start = 0; start < n; start++) { + min_val = *(long double *)(data + start * elsize); + if (!isnan(min_val)) { + *min_ind = start; + break; + } + } + + // If all values are NaN, return 0 + if (start == n) { + *min_ind = 0; + return 0; + } + + // Find minimum + for (npy_intp i = start + 1; i < n; i++) { + long double val = *(long double *)(data + i * elsize); + + // Skip NaN values + if (isnan(val)) { + continue; + } + + if (val < min_val) { + min_val = val; + *min_ind = i; + } + } + } + + return 0; +} + static PyType_Slot QuadPrecDType_Slots[] = { {NPY_DT_ensure_canonical, &ensure_canonical}, {NPY_DT_common_instance, &common_instance}, @@ -465,6 +633,8 @@ static PyType_Slot QuadPrecDType_Slots[] = { {NPY_DT_default_descr, &quadprec_default_descr}, {NPY_DT_get_constant, &quadprec_get_constant}, {NPY_DT_PyArray_ArrFuncs_compare, &quadprec_compare}, + {NPY_DT_PyArray_ArrFuncs_argmax, &quadprec_argmax}, + {NPY_DT_PyArray_ArrFuncs_argmin, &quadprec_argmin}, {NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill}, {NPY_DT_PyArray_ArrFuncs_scanfunc, &quadprec_scanfunc}, {NPY_DT_PyArray_ArrFuncs_fromstr, &quadprec_fromstr}, diff --git a/tests/test_quaddtype.py b/tests/test_quaddtype.py index f69b944..3d495b6 100644 --- a/tests/test_quaddtype.py +++ b/tests/test_quaddtype.py @@ -5745,3 +5745,42 @@ def test_sort_algorithms(self, backend, kind): expected = np.array([1, 2, 3, 5, 8, 9], dtype=QuadPrecDType(backend=backend)) np.testing.assert_array_equal(sorted_x, expected) + +@pytest.mark.parametrize("backend", ["sleef", "longdouble"]) +def test_argmax_argmin(backend): + """Test argmax and argmin operations.""" + # Basic integers + x = np.array([3, 1, 4, 1, 5, 9, 2, 6], dtype=QuadPrecDType(backend=backend)) + assert np.argmax(x) == 5 + assert np.argmin(x) == 1 + + # With infinity + x = np.array([1, float('inf'), 2, float('-inf'), 3], dtype=QuadPrecDType(backend=backend)) + assert np.argmax(x) == 1 # +inf is max + assert np.argmin(x) == 3 # -inf is min + + # With NaN (NaN should be ignored) + x = np.array([1, float('nan'), 5, 2], dtype=QuadPrecDType(backend=backend)) + assert np.argmax(x) == 2 + assert np.argmin(x) == 0 + + # All NaN returns index 0 + x = np.array([float('nan'), float('nan')], dtype=QuadPrecDType(backend=backend)) + assert np.argmax(x) == 0 + assert np.argmin(x) == 0 + + # 2D with axis + x = np.array([[1, 5, 3], [4, 2, 6]], dtype=QuadPrecDType(backend=backend)) + assert np.argmax(x) == 5 # flattened + assert np.argmin(x) == 0 # flattened + np.testing.assert_array_equal(np.argmax(x, axis=0), [1, 0, 1]) + np.testing.assert_array_equal(np.argmin(x, axis=0), [0, 1, 0]) + np.testing.assert_array_equal(np.argmax(x, axis=1), [1, 2]) + np.testing.assert_array_equal(np.argmin(x, axis=1), [0, 1]) + + # Empty array raises ValueError + x = np.array([], dtype=QuadPrecDType(backend=backend)) + with pytest.raises(ValueError): + np.argmax(x) + with pytest.raises(ValueError): + np.argmin(x) \ No newline at end of file