1
1
#######################################################
2
- # Copyright (c) 2019 , ArrayFire
2
+ # Copyright (c) 2020 , ArrayFire
3
3
# All rights reserved.
4
4
#
5
5
# This file is distributed under 3-clause BSD license.
14
14
from .array import Array
15
15
from .library import backend , safe_call , BINARYOP , c_bool_t , c_double_t , c_int_t , c_pointer , c_uint_t
16
16
17
+
17
18
def _parallel_dim (a , dim , c_func ):
18
19
out = Array ()
19
20
safe_call (c_func (c_pointer (out .arr ), a .arr , c_int_t (dim )))
20
21
return out
21
22
23
+
22
24
def _reduce_all (a , c_func ):
23
25
real = c_double_t (0 )
24
26
imag = c_double_t (0 )
@@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
29
31
imag = imag .value
30
32
return real if imag == 0 else real + imag * 1j
31
33
34
+
32
35
def _nan_parallel_dim (a , dim , c_func , nan_val ):
33
36
out = Array ()
34
37
safe_call (c_func (c_pointer (out .arr ), a .arr , c_int_t (dim ), c_double_t (nan_val )))
35
38
return out
36
39
40
+
37
41
def _nan_reduce_all (a , c_func , nan_val ):
38
42
real = c_double_t (0 )
39
43
imag = c_double_t (0 )
@@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
44
48
imag = imag .value
45
49
return real if imag == 0 else real + imag * 1j
46
50
51
+
47
52
def _FNSD (dim , dims ):
48
53
if dim >= 0 :
49
54
return int (dim )
@@ -55,20 +60,26 @@ def _FNSD(dim, dims):
55
60
break
56
61
return int (fnsd )
57
62
63
+
58
64
def _rbk_dim (keys , vals , dim , c_func ):
59
65
keys_out = Array ()
60
66
vals_out = Array ()
61
67
rdim = _FNSD (dim , vals .dims ())
62
68
safe_call (c_func (c_pointer (keys_out .arr ), c_pointer (vals_out .arr ), keys .arr , vals .arr , c_int_t (rdim )))
63
69
return keys_out , vals_out
64
70
71
+
65
72
def _nan_rbk_dim (a , dim , c_func , nan_val ):
66
73
keys_out = Array ()
67
74
vals_out = Array ()
75
+ # FIXME: vals is undefined
68
76
rdim = _FNSD (dim , vals .dims ())
69
- safe_call (c_func (c_pointer (keys_out .arr ), c_pointer (vals_out .arr ), keys .arr , vals .arr , c_int_t (rdim ), c_double_t (nan_val )))
77
+ # FIXME: keys is undefined
78
+ safe_call (c_func (
79
+ c_pointer (keys_out .arr ), c_pointer (vals_out .arr ), keys .arr , vals .arr , c_int_t (rdim ), c_double_t (nan_val )))
70
80
return keys_out , vals_out
71
81
82
+
72
83
def sum (a , dim = None , nan_val = None ):
73
84
"""
74
85
Calculate the sum of all the elements along a specified dimension.
@@ -88,18 +99,16 @@ def sum(a, dim=None, nan_val=None):
88
99
The sum of all elements in `a` along dimension `dim`.
89
100
If `dim` is `None`, sum of the entire Array is returned.
90
101
"""
91
- if nan_val is not None :
92
- if dim is not None :
102
+ if nan_val :
103
+ if dim :
93
104
return _nan_parallel_dim (a , dim , backend .get ().af_sum_nan , nan_val )
94
105
return _nan_reduce_all (a , backend .get ().af_sum_nan_all , nan_val )
95
106
96
- if dim is not None :
107
+ if dim :
97
108
return _parallel_dim (a , dim , backend .get ().af_sum )
98
109
return _reduce_all (a , backend .get ().af_sum_all )
99
110
100
111
101
-
102
-
103
112
def sumByKey (keys , vals , dim = - 1 , nan_val = None ):
104
113
"""
105
114
Calculate the sum of elements along a specified dimension according to a key.
@@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
122
131
values: af.Array or scalar number
123
132
The sum of all elements in `vals` along dimension `dim` according to keys
124
133
"""
125
- if ( nan_val is not None ) :
134
+ if nan_val :
126
135
return _nan_rbk_dim (keys , vals , dim , backend .get ().af_sum_by_key_nan , nan_val )
127
- else :
128
- return _rbk_dim ( keys , vals , dim , backend . get (). af_sum_by_key )
136
+ return _rbk_dim ( keys , vals , dim , backend . get (). af_sum_by_key )
137
+
129
138
130
139
def product (a , dim = None , nan_val = None ):
131
140
"""
@@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
178
187
values: af.Array or scalar number
179
188
The product of all elements in `vals` along dimension `dim` according to keys
180
189
"""
181
- if ( nan_val is not None ) :
190
+ if nan_val is not None :
182
191
return _nan_rbk_dim (keys , vals , dim , backend .get ().af_product_by_key_nan , nan_val )
183
- else :
184
- return _rbk_dim ( keys , vals , dim , backend . get (). af_product_by_key )
192
+ return _rbk_dim ( keys , vals , dim , backend . get (). af_product_by_key )
193
+
185
194
186
195
def min (a , dim = None ):
187
196
"""
@@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
227
236
"""
228
237
return _rbk_dim (keys , vals , dim , backend .get ().af_min_by_key )
229
238
239
+
230
240
def max (a , dim = None ):
231
241
"""
232
242
Find the maximum value of all the elements along a specified dimension.
@@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
271
281
"""
272
282
return _rbk_dim (keys , vals , dim , backend .get ().af_max_by_key )
273
283
284
+
274
285
def all_true (a , dim = None ):
275
286
"""
276
287
Check if all the elements along a specified dimension are true.
@@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
315
326
"""
316
327
return _rbk_dim (keys , vals , dim , backend .get ().af_all_true_by_key )
317
328
329
+
318
330
def any_true (a , dim = None ):
319
331
"""
320
332
Check if any the elements along a specified dimension are true.
@@ -334,8 +346,8 @@ def any_true(a, dim=None):
334
346
"""
335
347
if dim is not None :
336
348
return _parallel_dim (a , dim , backend .get ().af_any_true )
337
- else :
338
- return _reduce_all ( a , backend . get (). af_any_true_all )
349
+ return _reduce_all ( a , backend . get (). af_any_true_all )
350
+
339
351
340
352
def anyTrueByKey (keys , vals , dim = - 1 ):
341
353
"""
@@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
359
371
"""
360
372
return _rbk_dim (keys , vals , dim , backend .get ().af_any_true_by_key )
361
373
374
+
362
375
def count (a , dim = None ):
363
376
"""
364
377
Count the number of non zero elements in an array along a specified dimension.
@@ -378,8 +391,7 @@ def count(a, dim=None):
378
391
"""
379
392
if dim is not None :
380
393
return _parallel_dim (a , dim , backend .get ().af_count )
381
- else :
382
- return _reduce_all (a , backend .get ().af_count_all )
394
+ return _reduce_all (a , backend .get ().af_count_all )
383
395
384
396
385
397
def countByKey (keys , vals , dim = - 1 ):
@@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
404
416
"""
405
417
return _rbk_dim (keys , vals , dim , backend .get ().af_count_by_key )
406
418
419
+
407
420
def imin (a , dim = None ):
408
421
"""
409
422
Find the value and location of the minimum value along a specified dimension
0 commit comments