11#######################################################
2- # Copyright (c) 2019 , ArrayFire
2+ # Copyright (c) 2020 , ArrayFire
33# All rights reserved.
44#
55# This file is distributed under 3-clause BSD license.
1414from .array import Array
1515from .library import backend , safe_call , BINARYOP , c_bool_t , c_double_t , c_int_t , c_pointer , c_uint_t
1616
17+
1718def _parallel_dim (a , dim , c_func ):
1819 out = Array ()
1920 safe_call (c_func (c_pointer (out .arr ), a .arr , c_int_t (dim )))
2021 return out
2122
23+
2224def _reduce_all (a , c_func ):
2325 real = c_double_t (0 )
2426 imag = c_double_t (0 )
@@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
2931 imag = imag .value
3032 return real if imag == 0 else real + imag * 1j
3133
34+
3235def _nan_parallel_dim (a , dim , c_func , nan_val ):
3336 out = Array ()
3437 safe_call (c_func (c_pointer (out .arr ), a .arr , c_int_t (dim ), c_double_t (nan_val )))
3538 return out
3639
40+
3741def _nan_reduce_all (a , c_func , nan_val ):
3842 real = c_double_t (0 )
3943 imag = c_double_t (0 )
@@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
4448 imag = imag .value
4549 return real if imag == 0 else real + imag * 1j
4650
51+
4752def _FNSD (dim , dims ):
4853 if dim >= 0 :
4954 return int (dim )
@@ -55,20 +60,26 @@ def _FNSD(dim, dims):
5560 break
5661 return int (fnsd )
5762
63+
5864def _rbk_dim (keys , vals , dim , c_func ):
5965 keys_out = Array ()
6066 vals_out = Array ()
6167 rdim = _FNSD (dim , vals .dims ())
6268 safe_call (c_func (c_pointer (keys_out .arr ), c_pointer (vals_out .arr ), keys .arr , vals .arr , c_int_t (rdim )))
6369 return keys_out , vals_out
6470
71+
6572def _nan_rbk_dim (a , dim , c_func , nan_val ):
6673 keys_out = Array ()
6774 vals_out = Array ()
75+ # FIXME: vals is undefined
6876 rdim = _FNSD (dim , vals .dims ())
69- safe_call (c_func (c_pointer (keys_out .arr ), c_pointer (vals_out .arr ), keys .arr , vals .arr , c_int_t (rdim ), c_double_t (nan_val )))
77+ # FIXME: keys is undefined
78+ safe_call (c_func (
79+ c_pointer (keys_out .arr ), c_pointer (vals_out .arr ), keys .arr , vals .arr , c_int_t (rdim ), c_double_t (nan_val )))
7080 return keys_out , vals_out
7181
82+
7283def sum (a , dim = None , nan_val = None ):
7384 """
7485 Calculate the sum of all the elements along a specified dimension.
@@ -98,8 +109,6 @@ def sum(a, dim=None, nan_val=None):
98109 return _reduce_all (a , backend .get ().af_sum_all )
99110
100111
101-
102-
103112def sumByKey (keys , vals , dim = - 1 , nan_val = None ):
104113 """
105114 Calculate the sum of elements along a specified dimension according to a key.
@@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
122131 values: af.Array or scalar number
123132 The sum of all elements in `vals` along dimension `dim` according to keys
124133 """
125- if ( nan_val is not None ) :
134+ if nan_val is not None :
126135 return _nan_rbk_dim (keys , vals , dim , backend .get ().af_sum_by_key_nan , nan_val )
127- else :
128- return _rbk_dim ( keys , vals , dim , backend . get (). af_sum_by_key )
136+ return _rbk_dim ( keys , vals , dim , backend . get (). af_sum_by_key )
137+
129138
130139def product (a , dim = None , nan_val = None ):
131140 """
@@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
178187 values: af.Array or scalar number
179188 The product of all elements in `vals` along dimension `dim` according to keys
180189 """
181- if ( nan_val is not None ) :
190+ if nan_val is not None :
182191 return _nan_rbk_dim (keys , vals , dim , backend .get ().af_product_by_key_nan , nan_val )
183- else :
184- return _rbk_dim ( keys , vals , dim , backend . get (). af_product_by_key )
192+ return _rbk_dim ( keys , vals , dim , backend . get (). af_product_by_key )
193+
185194
186195def min (a , dim = None ):
187196 """
@@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
227236 """
228237 return _rbk_dim (keys , vals , dim , backend .get ().af_min_by_key )
229238
239+
230240def max (a , dim = None ):
231241 """
232242 Find the maximum value of all the elements along a specified dimension.
@@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
271281 """
272282 return _rbk_dim (keys , vals , dim , backend .get ().af_max_by_key )
273283
284+
274285def all_true (a , dim = None ):
275286 """
276287 Check if all the elements along a specified dimension are true.
@@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
315326 """
316327 return _rbk_dim (keys , vals , dim , backend .get ().af_all_true_by_key )
317328
329+
318330def any_true (a , dim = None ):
319331 """
320332 Check if any the elements along a specified dimension are true.
@@ -334,8 +346,8 @@ def any_true(a, dim=None):
334346 """
335347 if dim is not None :
336348 return _parallel_dim (a , dim , backend .get ().af_any_true )
337- else :
338- return _reduce_all ( a , backend . get (). af_any_true_all )
349+ return _reduce_all ( a , backend . get (). af_any_true_all )
350+
339351
340352def anyTrueByKey (keys , vals , dim = - 1 ):
341353 """
@@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
359371 """
360372 return _rbk_dim (keys , vals , dim , backend .get ().af_any_true_by_key )
361373
374+
362375def count (a , dim = None ):
363376 """
364377 Count the number of non zero elements in an array along a specified dimension.
@@ -378,8 +391,7 @@ def count(a, dim=None):
378391 """
379392 if dim is not None :
380393 return _parallel_dim (a , dim , backend .get ().af_count )
381- else :
382- return _reduce_all (a , backend .get ().af_count_all )
394+ return _reduce_all (a , backend .get ().af_count_all )
383395
384396
385397def countByKey (keys , vals , dim = - 1 ):
@@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
404416 """
405417 return _rbk_dim (keys , vals , dim , backend .get ().af_count_by_key )
406418
419+
407420def imin (a , dim = None ):
408421 """
409422 Find the value and location of the minimum value along a specified dimension
0 commit comments