-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathintrin.py
149 lines (118 loc) · 5.46 KB
/
intrin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def load_int(to, address, const=True):
    """Emit a C declaration initialised with ``_mm256_loadu_si256``.

    Args:
        to: Name of the ``__m256i`` variable to declare.
        address: C expression inserted verbatim as the intrinsic's argument.
        const: When truthy the declaration is prefixed with ``const``.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    qualifier = "const " if const else ""
    return f"{qualifier}__m256i {to} = _mm256_loadu_si256({address});"
def load_fp(to, address, const=True):
    """Emit a C declaration initialised with ``_mm256_loadu_ps``.

    Args:
        to: Name of the ``__m256`` variable to declare.
        address: C expression inserted verbatim as the intrinsic's argument.
        const: When truthy the declaration is prefixed with ``const``.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    declaration = f"__m256 {to} = _mm256_loadu_ps({address});"
    if const:
        return "const " + declaration
    return declaration
# to = a * b + c
def vfma(to, a, b, c):
    """Emit a C declaration initialised with ``_mm256_fmadd_ps(a, b, c)``.

    Args:
        to: Name of the ``__m256`` variable to declare (emitted without
            ``const``).
        a: First operand expression.
        b: Second operand expression.
        c: Third operand expression.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    fma_call = f"_mm256_fmadd_ps({a}, {b}, {c})"
    return f"__m256 {to} = {fma_call};"
def vsrli(to, a, b):
    """Emit a C declaration initialised with ``_mm256_srli_epi32(a, b)``.

    Args:
        to: Name of the ``const __m256i`` variable to declare.
        a: Expression passed as the intrinsic's first argument.
        b: Expression passed as the intrinsic's second argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    shift_call = f"_mm256_srli_epi32({a}, {b})"
    return f"const __m256i {to} = {shift_call};"
def vand(to, a, b):
    """Emit a C declaration initialised with ``_mm256_and_si256(a, b)``.

    Args:
        to: Name of the ``const __m256i`` variable to declare.
        a: First operand expression.
        b: Second operand expression.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    template = "const __m256i {} = _mm256_and_si256({}, {});"
    return template.format(to, a, b)
def vbroadcast_fp(to, a):
    """Emit a C declaration initialised with ``_mm256_set1_ps(a)``.

    Args:
        to: Name of the ``const __m256`` variable to declare.
        a: Expression passed as the intrinsic's argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    set1_call = f"_mm256_set1_ps({a})"
    return f"const __m256 {to} = {set1_call};"
def vbroadcast_int32(to, a):
    """Emit a C declaration initialised with ``_mm256_set1_epi32(a)``.

    Args:
        to: Name of the ``__m256i`` variable to declare (emitted without
            ``const``).
        a: Expression passed as the intrinsic's argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    template = "__m256i {} = _mm256_set1_epi32({});"
    return template.format(to, a)
def vsetzero(to):
    """Emit a C declaration initialised with ``_mm256_setzero_ps()``.

    Args:
        to: Name of the ``__m256`` variable to declare (emitted without
            ``const``).

    Returns:
        One C statement as a string, without a trailing newline.
    """
    return "__m256 {} = _mm256_setzero_ps();".format(to)
def vcvtepi32_ps(to, a):
    """Emit a C declaration initialised with ``_mm256_cvtepi32_ps(a)``.

    Args:
        to: Name of the ``const __m256`` variable to declare.
        a: Expression passed as the intrinsic's argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    convert_call = f"_mm256_cvtepi32_ps({a})"
    return f"const __m256 {to} = {convert_call};"
def _256extractf128_ps(to, a, imm):
    """Emit a C declaration initialised with ``_mm256_extractf128_ps(a, imm)``.

    Args:
        to: Name of the ``const __m128`` variable to declare.
        a: Expression passed as the intrinsic's first argument.
        imm: Value rendered as the intrinsic's second argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    extract_call = f"_mm256_extractf128_ps({a}, {imm})"
    return f"const __m128 {to} = {extract_call};"
def _256castps256_ps128(to, a):
    """Emit a C declaration initialised with ``_mm256_castps256_ps128(a)``.

    Args:
        to: Name of the ``const __m128`` variable to declare.
        a: Expression passed as the intrinsic's argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    template = "const __m128 {} = _mm256_castps256_ps128({});"
    return template.format(to, a)
def _add_ps(to, a, b):
    """Emit a C declaration initialised with ``_mm_add_ps(a, b)``.

    Args:
        to: Name of the ``const __m128`` variable to declare.
        a: First operand expression.
        b: Second operand expression.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    add_call = f"_mm_add_ps({a}, {b})"
    return f"const __m128 {to} = {add_call};"
def _movehl_ps(to, a, b):
    """Emit a C declaration initialised with ``_mm_movehl_ps(a, b)``.

    Args:
        to: Name of the ``const __m128`` variable to declare.
        a: Expression passed as the intrinsic's first argument.
        b: Expression passed as the intrinsic's second argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    template = "const __m128 {} = _mm_movehl_ps({}, {});"
    return template.format(to, a, b)
def _shuffle_ps(to, a, b, imm):
    """Emit a C declaration initialised with ``_mm_shuffle_ps(a, b, imm)``.

    Args:
        to: Name of the ``const __m128`` variable to declare.
        a: Expression passed as the intrinsic's first argument.
        b: Expression passed as the intrinsic's second argument.
        imm: Value rendered as the intrinsic's third argument; a Python int
            is rendered in decimal.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    shuffle_call = f"_mm_shuffle_ps({a}, {b}, {imm})"
    return f"const __m128 {to} = {shuffle_call};"
def _cvtss_f32(to, a):
    """Emit a C declaration initialised with ``_mm_cvtss_f32(a)``.

    Args:
        to: Name of the ``const float`` variable to declare.
        a: Expression passed as the intrinsic's argument.

    Returns:
        One C statement as a string, without a trailing newline.
    """
    template = "const float {} = _mm_cvtss_f32({});"
    return template.format(to, a)
def _reduce8_acc(a, b, c, d, e, f, g, h):
    """Emit the C statements that reduce eight accumulators to eight floats.

    For each accumulator the emitted code extracts the upper 128 bits, casts
    to the lower 128 bits, adds the two, then folds the result twice more
    (``_mm_movehl_ps`` + add, ``_mm_shuffle_ps`` with immediate 1 + add) and
    finally reads a ``float`` out with ``_mm_cvtss_f32``.

    Statements are emitted stage by stage: all eight extracts, then all eight
    casts, and so on, rather than accumulator by accumulator.

    Args:
        a, b, c, d, e, f, g, h: Expressions naming the eight ``__m256``
            accumulators. Each is also embedded in the name of its result
            variable (``f`` + the argument text), so it has to be usable as
            part of a C identifier.

    Returns:
        All statements concatenated into one string with no separators.

    Note:
        The temporaries use fixed names (``hi_quad0`` .. ``sum7``), so
        emitting this sequence twice into the same C scope redeclares them.
    """
    accumulators = (a, b, c, d, e, f, g, h)
    slots = range(len(accumulators))
    stmts = []

    # Split every accumulator into its two 128-bit halves and add them.
    stmts.extend(
        _256extractf128_ps(f"hi_quad{i}", acc, 1)
        for i, acc in enumerate(accumulators)
    )
    stmts.extend(
        _256castps256_ps128(f"lo_quad{i}", acc)
        for i, acc in enumerate(accumulators)
    )
    stmts.extend(
        _add_ps(f"sum_quad{i}", f"lo_quad{i}", f"hi_quad{i}") for i in slots
    )

    # Fold the 128-bit sums once more via _mm_movehl_ps.
    stmts.extend(
        _movehl_ps(f"hi_dual{i}", f"sum_quad{i}", f"sum_quad{i}") for i in slots
    )
    stmts.extend(
        _add_ps(f"sum_dual{i}", f"sum_quad{i}", f"hi_dual{i}") for i in slots
    )

    # Final fold via _mm_shuffle_ps with immediate 1.
    stmts.extend(
        _shuffle_ps(f"hi{i}", f"sum_dual{i}", f"sum_dual{i}", 0x1) for i in slots
    )
    stmts.extend(_add_ps(f"sum{i}", f"sum_dual{i}", f"hi{i}") for i in slots)

    # Read the scalar result for each accumulator into "f<accumulator>".
    stmts.extend(
        _cvtss_f32(f"f{acc}", f"sum{i}") for i, acc in enumerate(accumulators)
    )
    return "".join(stmts)
# Suffix counter for the temporaries emitted by _reduce_add; it is bumped on
# every call so that successive reductions declare distinct C names.
acc_idx = 0


def _reduce_add(a):
    """Emit the C statements that reduce one accumulator to a float.

    The emitted sequence is: extract the upper 128 bits, cast to the lower
    128 bits, add them, fold with ``_mm_movehl_ps`` + add, fold with
    ``_mm_shuffle_ps`` (immediate 1) + add, then ``_mm_cvtss_f32``.

    Every temporary is suffixed with the current value of the module-level
    ``acc_idx``, which is incremented by one before returning.

    Args:
        a: Expression naming the ``__m256`` accumulator. It is also embedded
            in the result variable's name (``f`` + the argument text).

    Returns:
        All statements concatenated into one string with no separators.
    """
    global acc_idx
    tag = acc_idx

    hi_quad = f"hi_quad{tag}"
    lo_quad = f"lo_quad{tag}"
    sum_quad = f"sum_quad{tag}"
    hi_dual = f"hi_dual{tag}"
    sum_dual = f"sum_dual{tag}"
    hi = f"hi{tag}"
    total = f"sum{tag}"

    stmts = [
        _256extractf128_ps(hi_quad, a, 1),
        _256castps256_ps128(lo_quad, a),
        _add_ps(sum_quad, lo_quad, hi_quad),
        _movehl_ps(hi_dual, sum_quad, sum_quad),
        _add_ps(sum_dual, sum_quad, hi_dual),
        _shuffle_ps(hi, sum_dual, sum_dual, 0x1),
        _add_ps(total, sum_dual, hi),
        _cvtss_f32(f"f{a}", total),
    ]

    acc_idx += 1
    return "".join(stmts)