-
-
Notifications
You must be signed in to change notification settings - Fork 4
FEAT: Implementing argmax & argmin PyArray slots for quaddtype #50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -455,6 +455,174 @@ quadprec_compare(void *a, void *b, void *arr) | |
| } | ||
| } | ||
|
|
||
| /* | ||
| * Argmax function for np.argmax() | ||
| * Finds the index of the maximum element. | ||
| * NaN values are ignored unless all values are NaN. | ||
| */ | ||
| static int | ||
| quadprec_argmax(char *data, npy_intp n, npy_intp *max_ind, void *arr) | ||
| { | ||
| PyArrayObject *array = (PyArrayObject *)arr; | ||
| QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(array); | ||
| npy_intp elsize = descr->base.elsize; | ||
|
|
||
| *max_ind = 0; | ||
|
|
||
| if (descr->backend == BACKEND_SLEEF) { | ||
| // Find first non-NaN value as initial max | ||
| npy_intp start = 0; | ||
| Sleef_quad max_val; | ||
| for (start = 0; start < n; start++) { | ||
| max_val = *(Sleef_quad *)(data + start * elsize); | ||
| if (!Sleef_iunordq1(max_val, max_val)) { | ||
| *max_ind = start; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // If all values are NaN, return 0 | ||
| if (start == n) { | ||
| *max_ind = 0; | ||
| return 0; | ||
| } | ||
|
|
||
| // Find maximum | ||
| for (npy_intp i = start + 1; i < n; i++) { | ||
| Sleef_quad val = *(Sleef_quad *)(data + i * elsize); | ||
|
|
||
| // Skip NaN values | ||
| if (Sleef_iunordq1(val, val)) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It maybe not preferable to use these routines directly in code since we already have their abstractions inside ops.hpp, using them wherever requires might be better as something goes off then just fix that abstraction. |
||
| continue; | ||
| } | ||
|
|
||
| if (Sleef_icmpgtq1(val, max_val)) { | ||
| max_val = val; | ||
| *max_ind = i; | ||
| } | ||
| } | ||
SwayamInSync marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| else { | ||
| // Find first non-NaN value as initial max | ||
| npy_intp start = 0; | ||
| long double max_val; | ||
| for (start = 0; start < n; start++) { | ||
| max_val = *(long double *)(data + start * elsize); | ||
| if (!isnan(max_val)) { | ||
| *max_ind = start; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // If all values are NaN, return 0 | ||
| if (start == n) { | ||
| *max_ind = 0; | ||
| return 0; | ||
| } | ||
|
|
||
| // Find maximum | ||
| for (npy_intp i = start + 1; i < n; i++) { | ||
| long double val = *(long double *)(data + i * elsize); | ||
|
|
||
| // Skip NaN values | ||
| if (isnan(val)) { | ||
| continue; | ||
| } | ||
|
|
||
| if (val > max_val) { | ||
| max_val = val; | ||
| *max_ind = i; | ||
| } | ||
| } | ||
SwayamInSync marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| return 0; | ||
| } | ||
|
|
||
| /* | ||
| * Argmin function for np.argmin() | ||
| * Finds the index of the minimum element. | ||
| * NaN values are ignored unless all values are NaN. | ||
| */ | ||
| static int | ||
| quadprec_argmin(char *data, npy_intp n, npy_intp *min_ind, void *arr) | ||
| { | ||
| PyArrayObject *array = (PyArrayObject *)arr; | ||
| QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(array); | ||
| npy_intp elsize = descr->base.elsize; | ||
|
|
||
| *min_ind = 0; | ||
|
|
||
| if (descr->backend == BACKEND_SLEEF) { | ||
| // Find first non-NaN value as initial min | ||
| npy_intp start = 0; | ||
| Sleef_quad min_val; | ||
| for (start = 0; start < n; start++) { | ||
| min_val = *(Sleef_quad *)(data + start * elsize); | ||
| if (!Sleef_iunordq1(min_val, min_val)) { | ||
| *min_ind = start; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // If all values are NaN, return 0 | ||
| if (start == n) { | ||
| *min_ind = 0; | ||
| return 0; | ||
| } | ||
|
|
||
| // Find minimum | ||
| for (npy_intp i = start + 1; i < n; i++) { | ||
| Sleef_quad val = *(Sleef_quad *)(data + i * elsize); | ||
|
|
||
| // Skip NaN values | ||
| if (Sleef_iunordq1(val, val)) { | ||
| continue; | ||
| } | ||
|
|
||
| if (Sleef_icmpltq1(val, min_val)) { | ||
| min_val = val; | ||
| *min_ind = i; | ||
| } | ||
| } | ||
SwayamInSync marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| else { | ||
| // Find first non-NaN value as initial min | ||
| npy_intp start = 0; | ||
| long double min_val; | ||
| for (start = 0; start < n; start++) { | ||
| min_val = *(long double *)(data + start * elsize); | ||
| if (!isnan(min_val)) { | ||
| *min_ind = start; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // If all values are NaN, return 0 | ||
| if (start == n) { | ||
| *min_ind = 0; | ||
| return 0; | ||
| } | ||
|
|
||
| // Find minimum | ||
| for (npy_intp i = start + 1; i < n; i++) { | ||
| long double val = *(long double *)(data + i * elsize); | ||
|
|
||
| // Skip NaN values | ||
| if (isnan(val)) { | ||
| continue; | ||
| } | ||
|
|
||
| if (val < min_val) { | ||
| min_val = val; | ||
| *min_ind = i; | ||
| } | ||
| } | ||
SwayamInSync marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| return 0; | ||
| } | ||
|
|
||
| static PyType_Slot QuadPrecDType_Slots[] = { | ||
| {NPY_DT_ensure_canonical, &ensure_canonical}, | ||
| {NPY_DT_common_instance, &common_instance}, | ||
|
|
@@ -465,6 +633,8 @@ static PyType_Slot QuadPrecDType_Slots[] = { | |
| {NPY_DT_default_descr, &quadprec_default_descr}, | ||
| {NPY_DT_get_constant, &quadprec_get_constant}, | ||
| {NPY_DT_PyArray_ArrFuncs_compare, &quadprec_compare}, | ||
| {NPY_DT_PyArray_ArrFuncs_argmax, &quadprec_argmax}, | ||
| {NPY_DT_PyArray_ArrFuncs_argmin, &quadprec_argmin}, | ||
| {NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill}, | ||
| {NPY_DT_PyArray_ArrFuncs_scanfunc, &quadprec_scanfunc}, | ||
| {NPY_DT_PyArray_ArrFuncs_fromstr, &quadprec_fromstr}, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That feels wrong (probably does the right thing but ...). What I meant was, after the all-NaN check, load the maxval from the start index before you enter into the loop that finds the total maximum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though then again, we're already loading the values ... so this does save one read ... so maybe it's fine or even better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That'll be actually worse
valin the NaN-finding loop (it is must to cast the data to quad pointer)max_val = *(Sleef_quad *)(data + start * elsize)