Skip to content

Commit ad77595

Browse files
committed
Fix flake8 errors. Minor refactoring and bugfixes
1 parent 28a0f61 commit ad77595

18 files changed

+493
-546
lines changed

arrayfire/algorithm.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#######################################################
2-
# Copyright (c) 2019, ArrayFire
2+
# Copyright (c) 2020, ArrayFire
33
# All rights reserved.
44
#
55
# This file is distributed under 3-clause BSD license.
@@ -14,11 +14,13 @@
1414
from .array import Array
1515
from .library import backend, safe_call, BINARYOP, c_bool_t, c_double_t, c_int_t, c_pointer, c_uint_t
1616

17+
1718
def _parallel_dim(a, dim, c_func):
1819
out = Array()
1920
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim)))
2021
return out
2122

23+
2224
def _reduce_all(a, c_func):
2325
real = c_double_t(0)
2426
imag = c_double_t(0)
@@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
2931
imag = imag.value
3032
return real if imag == 0 else real + imag * 1j
3133

34+
3235
def _nan_parallel_dim(a, dim, c_func, nan_val):
3336
out = Array()
3437
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim), c_double_t(nan_val)))
3538
return out
3639

40+
3741
def _nan_reduce_all(a, c_func, nan_val):
3842
real = c_double_t(0)
3943
imag = c_double_t(0)
@@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
4448
imag = imag.value
4549
return real if imag == 0 else real + imag * 1j
4650

51+
4752
def _FNSD(dim, dims):
4853
if dim >= 0:
4954
return int(dim)
@@ -55,20 +60,26 @@ def _FNSD(dim, dims):
5560
break
5661
return int(fnsd)
5762

63+
5864
def _rbk_dim(keys, vals, dim, c_func):
5965
keys_out = Array()
6066
vals_out = Array()
6167
rdim = _FNSD(dim, vals.dims())
6268
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
6369
return keys_out, vals_out
6470

71+
6572
def _nan_rbk_dim(a, dim, c_func, nan_val):
6673
keys_out = Array()
6774
vals_out = Array()
75+
# FIXME: vals is undefined
6876
rdim = _FNSD(dim, vals.dims())
69-
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
77+
# FIXME: keys is undefined
78+
safe_call(c_func(
79+
c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
7080
return keys_out, vals_out
7181

82+
7283
def sum(a, dim=None, nan_val=None):
7384
"""
7485
Calculate the sum of all the elements along a specified dimension.
@@ -88,18 +99,16 @@ def sum(a, dim=None, nan_val=None):
8899
The sum of all elements in `a` along dimension `dim`.
89100
If `dim` is `None`, sum of the entire Array is returned.
90101
"""
91-
if nan_val is not None:
92-
if dim is not None:
102+
if nan_val:
103+
if dim:
93104
return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
94105
return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)
95106

96-
if dim is not None:
107+
if dim:
97108
return _parallel_dim(a, dim, backend.get().af_sum)
98109
return _reduce_all(a, backend.get().af_sum_all)
99110

100111

101-
102-
103112
def sumByKey(keys, vals, dim=-1, nan_val=None):
104113
"""
105114
Calculate the sum of elements along a specified dimension according to a key.
@@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
122131
values: af.Array or scalar number
123132
The sum of all elements in `vals` along dimension `dim` according to keys
124133
"""
125-
if (nan_val is not None):
134+
if nan_val:
126135
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
127-
else:
128-
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
136+
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
137+
129138

130139
def product(a, dim=None, nan_val=None):
131140
"""
@@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
178187
values: af.Array or scalar number
179188
The product of all elements in `vals` along dimension `dim` according to keys
180189
"""
181-
if (nan_val is not None):
190+
if nan_val is not None:
182191
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
183-
else:
184-
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
192+
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
193+
185194

186195
def min(a, dim=None):
187196
"""
@@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
227236
"""
228237
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)
229238

239+
230240
def max(a, dim=None):
231241
"""
232242
Find the maximum value of all the elements along a specified dimension.
@@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
271281
"""
272282
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)
273283

284+
274285
def all_true(a, dim=None):
275286
"""
276287
Check if all the elements along a specified dimension are true.
@@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
315326
"""
316327
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)
317328

329+
318330
def any_true(a, dim=None):
319331
"""
320332
Check if any the elements along a specified dimension are true.
@@ -334,8 +346,8 @@ def any_true(a, dim=None):
334346
"""
335347
if dim is not None:
336348
return _parallel_dim(a, dim, backend.get().af_any_true)
337-
else:
338-
return _reduce_all(a, backend.get().af_any_true_all)
349+
return _reduce_all(a, backend.get().af_any_true_all)
350+
339351

340352
def anyTrueByKey(keys, vals, dim=-1):
341353
"""
@@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
359371
"""
360372
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)
361373

374+
362375
def count(a, dim=None):
363376
"""
364377
Count the number of non zero elements in an array along a specified dimension.
@@ -378,8 +391,7 @@ def count(a, dim=None):
378391
"""
379392
if dim is not None:
380393
return _parallel_dim(a, dim, backend.get().af_count)
381-
else:
382-
return _reduce_all(a, backend.get().af_count_all)
394+
return _reduce_all(a, backend.get().af_count_all)
383395

384396

385397
def countByKey(keys, vals, dim=-1):
@@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
404416
"""
405417
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)
406418

419+
407420
def imin(a, dim=None):
408421
"""
409422
Find the value and location of the minimum value along a specified dimension

arrayfire/arith.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def cast(a, dtype):
7777
out : af.Array
7878
array containing the values from `a` after converting to `dtype`.
7979
"""
80-
out=Array()
80+
out = Array()
8181
safe_call(backend.get().af_cast(c_pointer(out.arr), a.arr, dtype.value))
8282
return out
8383

@@ -156,15 +156,8 @@ def clamp(val, low, high):
156156
vdims = dim4_to_tuple(val.dims())
157157
vty = val.type()
158158

159-
if not is_low_array:
160-
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
161-
else:
162-
low_arr = low.arr
163-
164-
if not is_high_array:
165-
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
166-
else:
167-
high_arr = high.arr
159+
low_arr = low.arr if is_low_array else constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
160+
high_arr = high.arr if is_high_array else constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
168161

169162
safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
170163

@@ -1003,6 +996,7 @@ def sqrt(a):
1003996
"""
1004997
return _arith_unary_func(a, backend.get().af_sqrt)
1005998

999+
10061000
def rsqrt(a):
10071001
"""
10081002
Reciprocal or inverse square root of each element in the array.

0 commit comments

Comments
 (0)