-
Notifications
You must be signed in to change notification settings - Fork 1
/
daxpy.c
91 lines (82 loc) · 2.81 KB
/
daxpy.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
//
// daxpy - compute y := alpha * x + y
//
//
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <velintrin.h>
#define VLEN (256)
// Strided DAXPY: y[i*incy] := alpha * x[i*incx] + y[i*incy] for i in [0, n).
//
//   n     - number of logical elements; the call is a no-op when n <= 0
//   alpha - scalar multiplier applied to every x element
//   x     - input vector (only read)
//   incx  - distance between consecutive x elements, counted in doubles
//   y     - input/output vector, updated in place
//   incy  - distance between consecutive y elements, counted in doubles
//
// Logical element i is addressed as x + i*incx and y + i*incy, so the
// pointers are expected to reference logical element 0.
// NOTE(review): reference BLAS starts at the far end of the array when an
// increment is negative; that convention is NOT applied here - confirm what
// callers expect before passing negative increments.
void daxpy(int n, double alpha, double *x, int incx,
           double *y, int incy)
{
    // A non-positive count means there is nothing to do. Without this guard
    // a negative n would be compared as a huge unsigned value by the loop.
    if (n <= 0) {
        return;
    }

    const double *xp = x;
    double *yp = y;

    // Element strides widened to 64 bit so i * stride cannot overflow int.
    const int64_t sx = incx;
    const int64_t sy = incy;

    // The vector load/store intrinsics take their stride in bytes.
    const uint64_t incx_b = (uint64_t)(sx * (int64_t)sizeof(double));
    const uint64_t incy_b = (uint64_t)(sy * (int64_t)sizeof(double));

    for (int64_t i = 0; i < n; i += VLEN) {
        // The last chunk may hold fewer than VLEN elements.
        const int64_t left = n - i;
        const int vl = (int)(left < VLEN ? left : VLEN);

        // The chunk starts at logical element i, i.e. i*inc doubles from the
        // base pointer (not i doubles, which is only right for inc == 1).
        const double *xc = xp + i * sx;
        double *yc = yp + i * sy;

        __vr xv = _vel_vld_vssl(incx_b, xc, vl);
        __vr yv = _vel_vld_vssl(incy_b, yc, vl);
        __vr yr = _vel_vfmadd_vvsvl(yv, alpha, xv, vl);  // yv + alpha * xv
        _vel_vst_vssl(yr, incy_b, yc, vl);
    }
}
//
// unrolled version of daxpy
//
// Unrolled strided DAXPY: y[i*incy] := alpha * x[i*incx] + y[i*incy]
// for i in [0, n).
//
//   n     - number of logical elements; the call is a no-op when n <= 0
//   alpha - scalar multiplier applied to every x element
//   x     - input vector (only read)
//   incx  - distance between consecutive x elements, counted in doubles
//   y     - input/output vector, updated in place
//   incy  - distance between consecutive y elements, counted in doubles
//
// The work is done in three phases: groups of four full vectors, then groups
// of two full vectors, then single (possibly partial) vectors. Within a group
// all loads are issued first, then all multiply-adds, then all stores.
//
// Logical element i is addressed as x + i*incx and y + i*incy, so the
// pointers are expected to reference logical element 0.
// NOTE(review): reference BLAS starts at the far end of the array when an
// increment is negative; that convention is NOT applied here - confirm what
// callers expect before passing negative increments.
void daxpy_unr(int n, double alpha, double *x, int incx,
               double *y, int incy)
{
    // A non-positive count means there is nothing to do. Without this guard
    // a negative n would be treated as a huge unsigned value further down.
    if (n <= 0) {
        return;
    }

    const double *xp = x;
    double *yp = y;

    // Element strides and count widened to 64 bit so that index arithmetic
    // cannot overflow int.
    const int64_t sx = incx;
    const int64_t sy = incy;
    const int64_t total = n;

    // The vector load/store intrinsics take their stride in bytes.
    const uint64_t incx_b = (uint64_t)(sx * (int64_t)sizeof(double));
    const uint64_t incy_b = (uint64_t)(sy * (int64_t)sizeof(double));

    // Pointer distance (in doubles) covered by one full vector of VLEN
    // logical elements. It scales with the increment.
    const int64_t xstep = (int64_t)VLEN * sx;
    const int64_t ystep = (int64_t)VLEN * sy;

    int64_t i = 0;

    // Phase 1: four full vectors per iteration.
    for (; i + 4 * VLEN <= total; i += 4 * VLEN) {
        const int vl = VLEN;
        const double *xc = xp + i * sx;
        double *yc = yp + i * sy;

        __vr xv1 = _vel_vld_vssl(incx_b, xc,             vl);
        __vr xv2 = _vel_vld_vssl(incx_b, xc + 1 * xstep, vl);
        __vr xv3 = _vel_vld_vssl(incx_b, xc + 2 * xstep, vl);
        __vr xv4 = _vel_vld_vssl(incx_b, xc + 3 * xstep, vl);
        __vr yv1 = _vel_vld_vssl(incy_b, yc,             vl);
        __vr yv2 = _vel_vld_vssl(incy_b, yc + 1 * ystep, vl);
        __vr yv3 = _vel_vld_vssl(incy_b, yc + 2 * ystep, vl);
        __vr yv4 = _vel_vld_vssl(incy_b, yc + 3 * ystep, vl);

        __vr yr1 = _vel_vfmadd_vvsvl(yv1, alpha, xv1, vl);
        __vr yr2 = _vel_vfmadd_vvsvl(yv2, alpha, xv2, vl);
        __vr yr3 = _vel_vfmadd_vvsvl(yv3, alpha, xv3, vl);
        __vr yr4 = _vel_vfmadd_vvsvl(yv4, alpha, xv4, vl);

        _vel_vst_vssl(yr1, incy_b, yc,             vl);
        _vel_vst_vssl(yr2, incy_b, yc + 1 * ystep, vl);
        _vel_vst_vssl(yr3, incy_b, yc + 2 * ystep, vl);
        _vel_vst_vssl(yr4, incy_b, yc + 3 * ystep, vl);
    }

    // Phase 2: two full vectors per iteration (runs at most once after
    // phase 1, since fewer than 4*VLEN elements remain).
    for (; i + 2 * VLEN <= total; i += 2 * VLEN) {
        const int vl = VLEN;
        const double *xc = xp + i * sx;
        double *yc = yp + i * sy;

        __vr xv1 = _vel_vld_vssl(incx_b, xc,         vl);
        __vr xv2 = _vel_vld_vssl(incx_b, xc + xstep, vl);
        __vr yv1 = _vel_vld_vssl(incy_b, yc,         vl);
        __vr yv2 = _vel_vld_vssl(incy_b, yc + ystep, vl);

        __vr yr1 = _vel_vfmadd_vvsvl(yv1, alpha, xv1, vl);
        __vr yr2 = _vel_vfmadd_vvsvl(yv2, alpha, xv2, vl);

        _vel_vst_vssl(yr1, incy_b, yc,         vl);
        _vel_vst_vssl(yr2, incy_b, yc + ystep, vl);
    }

    // Phase 3: whatever is left, one vector at a time; the final vector may
    // be shorter than VLEN.
    for (; i < total; i += VLEN) {
        const int64_t left = total - i;
        const int vl = (int)(left < VLEN ? left : VLEN);
        const double *xc = xp + i * sx;
        double *yc = yp + i * sy;

        __vr xv = _vel_vld_vssl(incx_b, xc, vl);
        __vr yv = _vel_vld_vssl(incy_b, yc, vl);
        __vr yr = _vel_vfmadd_vvsvl(yv, alpha, xv, vl);  // yv + alpha * xv
        _vel_vst_vssl(yr, incy_b, yc, vl);
    }
}