Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions src/csrc/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,174 @@ quadprec_compare(void *a, void *b, void *arr)
}
}

/*
* Argmax function for np.argmax()
* Finds the index of the maximum element.
* NaN values are ignored unless all values are NaN.
*/
static int
quadprec_argmax(char *data, npy_intp n, npy_intp *max_ind, void *arr)
{
PyArrayObject *array = (PyArrayObject *)arr;
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(array);
npy_intp elsize = descr->base.elsize;

*max_ind = 0;

if (descr->backend == BACKEND_SLEEF) {
// Find first non-NaN value as initial max
npy_intp start = 0;
Sleef_quad max_val;
for (start = 0; start < n; start++) {
max_val = *(Sleef_quad *)(data + start * elsize);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That feels wrong (probably does the right thing but ...). What I meant was, after the all-NaN check, load the maxval from the start index before you enter into the loop that finds the total maximum

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though then again, we're already loading the values ... so this does save one read ... so maybe it's fine or even better?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That'll be actually worse

  • Uses 1 extra temporary variable val in the NaN-finding loop (the data must be cast to a quad pointer to read it)
  • After NaN loop, needs another read: max_val = *(Sleef_quad *)(data + start * elsize)
  • Extra memory access (even if cached, it's still an instruction)

if (!Sleef_iunordq1(max_val, max_val)) {
*max_ind = start;
break;
}
}

// If all values are NaN, return 0
if (start == n) {
*max_ind = 0;
return 0;
}

// Find maximum
for (npy_intp i = start + 1; i < n; i++) {
Sleef_quad val = *(Sleef_quad *)(data + i * elsize);

// Skip NaN values
if (Sleef_iunordq1(val, val)) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may not be preferable to call these routines directly in this code, since we already have abstractions for them inside ops.hpp; using those wherever required might be better, because if something goes wrong we only have to fix that abstraction.
However, those abstractions use C++ features such as templates, so we might have to rename this file (or another one) to .cpp.

continue;
}

if (Sleef_icmpgtq1(val, max_val)) {
max_val = val;
*max_ind = i;
}
}
}
else {
// Find first non-NaN value as initial max
npy_intp start = 0;
long double max_val;
for (start = 0; start < n; start++) {
max_val = *(long double *)(data + start * elsize);
if (!isnan(max_val)) {
*max_ind = start;
break;
}
}

// If all values are NaN, return 0
if (start == n) {
*max_ind = 0;
return 0;
}

// Find maximum
for (npy_intp i = start + 1; i < n; i++) {
long double val = *(long double *)(data + i * elsize);

// Skip NaN values
if (isnan(val)) {
continue;
}

if (val > max_val) {
max_val = val;
*max_ind = i;
}
}
}

return 0;
}

/*
 * Argmin hook used for np.argmin() (registered in the dtype slot table as
 * NPY_DT_PyArray_ArrFuncs_argmin).
 *
 * Reads `n` elements starting at `data`, each `descr->base.elsize` bytes
 * apart, and writes the index of the smallest non-NaN element to `*min_ind`.
 *
 *  - NaN elements are skipped and can never become the minimum.
 *  - If no non-NaN element is found (all NaN, or nothing to scan),
 *    `*min_ind` is left at 0.
 *  - Only a strictly smaller value replaces the current minimum, so the
 *    earliest index wins among equal values.
 *
 * `arr` is the owning array; its descriptor selects the SLEEF or the
 * long double storage backend. The function always returns 0.
 */
static int
quadprec_argmin(char *data, npy_intp n, npy_intp *min_ind, void *arr)
{
    QuadPrecDTypeObject *descr =
            (QuadPrecDTypeObject *)PyArray_DESCR((PyArrayObject *)arr);
    const npy_intp stride = descr->base.elsize;
    npy_intp idx = 0;

    // Default answer; stays in place when no non-NaN element exists.
    *min_ind = 0;

    if (descr->backend == BACKEND_SLEEF) {
        Sleef_quad best;

        // Seed the running minimum with the first ordered (non-NaN) value.
        // The candidate is loaded straight into `best` so no second read is
        // needed once the seed is found.
        while (idx < n) {
            best = *(Sleef_quad *)(data + idx * stride);
            if (!Sleef_iunordq1(best, best)) {
                break;
            }
            idx++;
        }
        if (idx >= n) {
            // Nothing but NaN: keep the default index.
            return 0;
        }
        *min_ind = idx;

        // Scan the remainder, ignoring NaN and keeping the first strict minimum.
        for (idx++; idx < n; idx++) {
            Sleef_quad cur = *(Sleef_quad *)(data + idx * stride);
            if (!Sleef_iunordq1(cur, cur) && Sleef_icmpltq1(cur, best)) {
                best = cur;
                *min_ind = idx;
            }
        }
        return 0;
    }

    // long double backend: same two-phase scan using native comparisons.
    long double best;

    while (idx < n) {
        best = *(long double *)(data + idx * stride);
        if (!isnan(best)) {
            break;
        }
        idx++;
    }
    if (idx >= n) {
        // Nothing but NaN: keep the default index.
        return 0;
    }
    *min_ind = idx;

    for (idx++; idx < n; idx++) {
        long double cur = *(long double *)(data + idx * stride);
        if (!isnan(cur) && cur < best) {
            best = cur;
            *min_ind = idx;
        }
    }
    return 0;
}

static PyType_Slot QuadPrecDType_Slots[] = {
{NPY_DT_ensure_canonical, &ensure_canonical},
{NPY_DT_common_instance, &common_instance},
Expand All @@ -465,6 +633,8 @@ static PyType_Slot QuadPrecDType_Slots[] = {
{NPY_DT_default_descr, &quadprec_default_descr},
{NPY_DT_get_constant, &quadprec_get_constant},
{NPY_DT_PyArray_ArrFuncs_compare, &quadprec_compare},
{NPY_DT_PyArray_ArrFuncs_argmax, &quadprec_argmax},
{NPY_DT_PyArray_ArrFuncs_argmin, &quadprec_argmin},
{NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill},
{NPY_DT_PyArray_ArrFuncs_scanfunc, &quadprec_scanfunc},
{NPY_DT_PyArray_ArrFuncs_fromstr, &quadprec_fromstr},
Expand Down
39 changes: 39 additions & 0 deletions tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5745,3 +5745,42 @@ def test_sort_algorithms(self, backend, kind):
expected = np.array([1, 2, 3, 5, 8, 9], dtype=QuadPrecDType(backend=backend))
np.testing.assert_array_equal(sorted_x, expected)


@pytest.mark.parametrize("backend", ["sleef", "longdouble"])
def test_argmax_argmin(backend):
    """Check np.argmax / np.argmin on QuadPrecDType arrays for one backend.

    Covers plain values, infinities, NaN handling (NaN is ignored, and an
    all-NaN array yields index 0), flattened and per-axis results on a 2D
    array, and the ValueError raised for an empty array.
    """
    dtype = QuadPrecDType(backend=backend)
    inf = float('inf')
    nan = float('nan')

    # (values, expected argmax, expected argmin) for the 1D scenarios.
    flat_cases = [
        # Basic integers
        ([3, 1, 4, 1, 5, 9, 2, 6], 5, 1),
        # With infinity: +inf is the max, -inf is the min
        ([1, inf, 2, -inf, 3], 1, 3),
        # With NaN (NaN should be ignored)
        ([1, nan, 5, 2], 2, 0),
        # All NaN returns index 0
        ([nan, nan], 0, 0),
    ]
    for values, want_max, want_min in flat_cases:
        x = np.array(values, dtype=dtype)
        assert np.argmax(x) == want_max
        assert np.argmin(x) == want_min

    # 2D: without an axis the result indexes the flattened array.
    grid = np.array([[1, 5, 3], [4, 2, 6]], dtype=dtype)
    assert np.argmax(grid) == 5
    assert np.argmin(grid) == 0

    # 2D with axis: (axis, expected argmax, expected argmin).
    axis_cases = [
        (0, [1, 0, 1], [0, 1, 0]),
        (1, [1, 2], [0, 1]),
    ]
    for axis, want_max, want_min in axis_cases:
        np.testing.assert_array_equal(np.argmax(grid, axis=axis), want_max)
        np.testing.assert_array_equal(np.argmin(grid, axis=axis), want_min)

    # Empty array raises ValueError for both reductions.
    empty = np.array([], dtype=dtype)
    for reduction in (np.argmax, np.argmin):
        with pytest.raises(ValueError):
            reduction(empty)