@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
423423 int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424424 int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425425
426+ dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
426427 dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427428 dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
428429 // gemv expects pointers to the beginning of memory arrays,
@@ -435,17 +436,25 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435436
436437 if (NA0 * NA1)
437438 {
438- // If A is neither C- nor F-contiguous, we make a copy.
439- // TODO:
440- // - if one stride is equal to "- elemsize", we can still call
441- // gemv on reversed matrix and vectors
442- // - if the copy is too long, maybe call vector/vector dot on
443- // each row instead
444- if ((PyArray_STRIDES(%(A)s)[0] < 0)
445- || (PyArray_STRIDES(%(A)s)[1] < 0)
446- || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447- && (PyArray_STRIDES(%(A)s)[1] != elemsize)))
439+ if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
448440 {
441+ // We can treat the array A as C-or F-contiguous by changing the order of iteration
442+ if (SA0 < 0){
443+ A_data += (NA0 -1) * SA0; // Jump to first row
444+ SA0 = -SA0; // Iterate over rows in reverse
445+ Sz = -Sz; // Iterate over y in reverse
446+ }
447+ if (SA1 < 0){
448+ A_data += (NA1 -1) * SA1; // Jump to first column
449+ SA1 = -SA1; // Iterate over columns in reverse
450+ Sx = -Sx; // Iterate over x in reverse
451+ }
452+ } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
453+ {
454+ // Array isn't contiguous, we have to make a copy
455+ // - if the copy is too long, maybe call vector/vector dot on
456+ // each row instead
457+ // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\ n", SA0, SA1);
449458 npy_intp dims[2];
450459 dims[0] = NA0;
451460 dims[1] = NA1;
@@ -458,16 +467,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
458467 %(A)s = A_copy;
459468 SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460469 SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
470+ A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
461471 }
462472
463- if (PyArray_STRIDES(%(A)s)[0] == elemsize )
473+ if (SA0 == 1 )
464474 {
465475 if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466476 {
467477 float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468478 sgemv_(&NOTRANS, &NA0, &NA1,
469479 &alpha,
470- (float*)(PyArray_DATA(%(A)s) ), &SA1,
480+ (float*)(A_data ), &SA1,
471481 (float*)x_data, &Sx,
472482 &fbeta,
473483 (float*)z_data, &Sz);
@@ -477,7 +487,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
477487 double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478488 dgemv_(&NOTRANS, &NA0, &NA1,
479489 &alpha,
480- (double*)(PyArray_DATA(%(A)s) ), &SA1,
490+ (double*)(A_data ), &SA1,
481491 (double*)x_data, &Sx,
482492 &dbeta,
483493 (double*)z_data, &Sz);
@@ -489,7 +499,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
489499 %(fail)s
490500 }
491501 }
492- else if (PyArray_STRIDES(%(A)s)[1] == elemsize )
502+ else if (SA1 == 1 )
493503 {
494504 if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495505 {
@@ -506,14 +516,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
506516 z_data[0] = 0.f;
507517 }
508518 z_data[0] += alpha*sdot_(&NA1,
509- (float*)(PyArray_DATA(%(A)s) ), &SA1,
519+ (float*)(A_data ), &SA1,
510520 (float*)x_data, &Sx);
511521 }
512522 else
513523 {
514524 sgemv_(&TRANS, &NA1, &NA0,
515525 &alpha,
516- (float*)(PyArray_DATA(%(A)s) ), &SA0,
526+ (float*)(A_data ), &SA0,
517527 (float*)x_data, &Sx,
518528 &fbeta,
519529 (float*)z_data, &Sz);
@@ -534,14 +544,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
534544 z_data[0] = 0.;
535545 }
536546 z_data[0] += alpha*ddot_(&NA1,
537- (double*)(PyArray_DATA(%(A)s) ), &SA1,
547+ (double*)(A_data ), &SA1,
538548 (double*)x_data, &Sx);
539549 }
540550 else
541551 {
542552 dgemv_(&TRANS, &NA1, &NA0,
543553 &alpha,
544- (double*)(PyArray_DATA(%(A)s) ), &SA0,
554+ (double*)(A_data ), &SA0,
545555 (double*)x_data, &Sx,
546556 &dbeta,
547557 (double*)z_data, &Sz);
@@ -603,7 +613,7 @@ def c_code(self, node, name, inp, out, sub):
603613 return code
604614
605615 def c_code_cache_version (self ):
606- return (14 , blas_header_version (), check_force_gemv_init ())
616+ return (15 , blas_header_version (), check_force_gemv_init ())
607617
608618
609619cgemv_inplace = CGemv (inplace = True )
0 commit comments