Skip to content

Commit 8b3aea1

Browse files
authored
Merge pull request #99 from arrayfire/features_3.4
Features 3.4
2 parents 135b6d5 + 3637f58 commit 8b3aea1

22 files changed

+1225
-177
lines changed

arrayfire/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
from .index import *
7373
from .interop import *
7474
from .timer import *
75+
from .random import *
76+
from .sparse import *
7577

7678
# do not export default modules as part of arrayfire
7779
del ct

arrayfire/algorithm.py

+65
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,71 @@ def accum(a, dim=0):
299299
"""
300300
return _parallel_dim(a, dim, backend.get().af_accum)
301301

302+
def scan(a, dim=0, op=BINARYOP.ADD, inclusive_scan=True):
303+
"""
304+
Generalized scan of an array.
305+
306+
Parameters
307+
----------
308+
a : af.Array
309+
Multi dimensional arrayfire array.
310+
311+
dim : optional: int. default: 0
312+
Dimension along which the scan is performed.
313+
314+
op : optional: af.BINARYOP. default: af.BINARYOP.ADD.
315+
Binary option the scan algorithm uses. Can be one of:
316+
- af.BINARYOP.ADD
317+
- af.BINARYOP.MUL
318+
- af.BINARYOP.MIN
319+
- af.BINARYOP.MAX
320+
321+
inclusive_scan: optional: bool. default: True
322+
Specifies if the scan is inclusive
323+
324+
Returns
325+
---------
326+
out : af.Array
327+
- will contain scan of input.
328+
"""
329+
out = Array()
330+
safe_call(backend.get().af_scan(ct.pointer(out.arr), a.arr, dim, op.value, inclusive_scan))
331+
return out
332+
333+
def scan_by_key(key, a, dim=0, op=BINARYOP.ADD, inclusive_scan=True):
334+
"""
335+
Generalized scan by key of an array.
336+
337+
Parameters
338+
----------
339+
key : af.Array
340+
key array.
341+
342+
a : af.Array
343+
Multi dimensional arrayfire array.
344+
345+
dim : optional: int. default: 0
346+
Dimension along which the scan is performed.
347+
348+
op : optional: af.BINARYOP. default: af.BINARYOP.ADD.
349+
Binary option the scan algorithm uses. Can be one of:
350+
- af.BINARYOP.ADD
351+
- af.BINARYOP.MUL
352+
- af.BINARYOP.MIN
353+
- af.BINARYOP.MAX
354+
355+
inclusive_scan: optional: bool. default: True
356+
Specifies if the scan is inclusive
357+
358+
Returns
359+
---------
360+
out : af.Array
361+
- will contain scan of input.
362+
"""
363+
out = Array()
364+
safe_call(backend.get().af_scan_by_key(ct.pointer(out.arr), key.arr, a.arr, dim, op.value, inclusive_scan))
365+
return out
366+
302367
def where(a):
303368
"""
304369
Find the indices of non zero elements

arrayfire/arith.py

+38
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,44 @@ def maxof(lhs, rhs):
126126
"""
127127
return _arith_binary_func(lhs, rhs, backend.get().af_maxof)
128128

129+
def clamp(val, low, high):
130+
"""
131+
Clamp the input value between low and high
132+
133+
134+
Parameters
135+
----------
136+
val : af.Array
137+
Multi dimensional arrayfire array to be clamped.
138+
139+
low : af.Array or scalar
140+
Multi dimensional arrayfire array or a scalar number denoting the lower value(s).
141+
142+
high : af.Array or scalar
143+
Multi dimensional arrayfire array or a scalar number denoting the higher value(s).
144+
"""
145+
out = Array()
146+
147+
is_low_array = isinstance(low, Array)
148+
is_high_array = isinstance(high, Array)
149+
150+
vdims = dim4_to_tuple(val.dims())
151+
vty = val.type()
152+
153+
if not is_low_array:
154+
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
155+
else:
156+
low_arr = low.arr
157+
158+
if not is_high_array:
159+
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
160+
else:
161+
high_arr = high.arr
162+
163+
safe_call(backend.get().af_clamp(ct.pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
164+
165+
return out
166+
129167
def rem(lhs, rhs):
130168
"""
131169
Find the remainder.

arrayfire/array.py

+8
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,14 @@ def is_vector(self):
667667
safe_call(backend.get().af_is_vector(ct.pointer(res), self.arr))
668668
return res.value
669669

670+
def is_sparse(self):
671+
"""
672+
Check if the array is a sparse matrix.
673+
"""
674+
res = ct.c_bool(False)
675+
safe_call(backend.get().af_is_sparse(ct.pointer(res), self.arr))
676+
return res.value
677+
670678
def is_complex(self):
671679
"""
672680
Check if the array is of complex type.

arrayfire/data.py

+1-99
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .array import *
1717
from .util import *
1818
from .util import _is_number
19+
from .random import randu, randn, set_seed, get_seed
1920

2021
def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
2122
"""
@@ -186,105 +187,6 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32)
186187
4, ct.pointer(tdims), dtype.value))
187188
return out
188189

189-
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
190-
"""
191-
Create a multi dimensional array containing values from a uniform distribution.
192-
193-
Parameters
194-
----------
195-
d0 : int.
196-
Length of first dimension.
197-
198-
d1 : optional: int. default: None.
199-
Length of second dimension.
200-
201-
d2 : optional: int. default: None.
202-
Length of third dimension.
203-
204-
d3 : optional: int. default: None.
205-
Length of fourth dimension.
206-
207-
dtype : optional: af.Dtype. default: af.Dtype.f32.
208-
Data type of the array.
209-
210-
Returns
211-
-------
212-
213-
out : af.Array
214-
Multi dimensional array whose elements are sampled uniformly between [0, 1].
215-
- If d1 is None, `out` is 1D of size (d0,).
216-
- If d1 is not None and d2 is None, `out` is 2D of size (d0, d1).
217-
- If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2).
218-
- If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3).
219-
"""
220-
out = Array()
221-
dims = dim4(d0, d1, d2, d3)
222-
223-
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
224-
return out
225-
226-
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
227-
"""
228-
Create a multi dimensional array containing values from a normal distribution.
229-
230-
Parameters
231-
----------
232-
d0 : int.
233-
Length of first dimension.
234-
235-
d1 : optional: int. default: None.
236-
Length of second dimension.
237-
238-
d2 : optional: int. default: None.
239-
Length of third dimension.
240-
241-
d3 : optional: int. default: None.
242-
Length of fourth dimension.
243-
244-
dtype : optional: af.Dtype. default: af.Dtype.f32.
245-
Data type of the array.
246-
247-
Returns
248-
-------
249-
250-
out : af.Array
251-
Multi dimensional array whose elements are sampled from a normal distribution with mean 0 and sigma of 1.
252-
- If d1 is None, `out` is 1D of size (d0,).
253-
- If d1 is not None and d2 is None, `out` is 2D of size (d0, d1).
254-
- If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2).
255-
- If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3).
256-
"""
257-
258-
out = Array()
259-
dims = dim4(d0, d1, d2, d3)
260-
261-
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
262-
return out
263-
264-
def set_seed(seed=0):
265-
"""
266-
Set the seed for the random number generator.
267-
268-
Parameters
269-
----------
270-
seed: int.
271-
Seed for the random number generator
272-
"""
273-
safe_call(backend.get().af_set_seed(ct.c_ulonglong(seed)))
274-
275-
def get_seed():
276-
"""
277-
Get the seed for the random number generator.
278-
279-
Returns
280-
----------
281-
seed: int.
282-
Seed for the random number generator
283-
"""
284-
seed = ct.c_ulonglong(0)
285-
safe_call(backend.get().af_get_seed(ct.pointer(seed)))
286-
return seed.value
287-
288190
def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32):
289191
"""
290192
Create an identity matrix or batch of identity matrices.

arrayfire/device.py

+90-10
Original file line numberDiff line numberDiff line change
@@ -163,24 +163,87 @@ def sync(device=None):
163163
safe_call(backend.get().af_sync(dev))
164164

165165
def __eval(*args):
166-
for A in args:
167-
if isinstance(A, tuple):
168-
__eval(*A)
169-
if isinstance(A, list):
170-
__eval(*A)
171-
if isinstance(A, Array):
172-
safe_call(backend.get().af_eval(A.arr))
166+
nargs = len(args)
167+
if (nargs == 1):
168+
safe_call(backend.get().af_eval(args[0].arr))
169+
else:
170+
c_void_p_n = ct.c_void_p * nargs
171+
arrs = c_void_p_n()
172+
for n in range(nargs):
173+
arrs[n] = args[n].arr
174+
safe_call(backend.get().af_eval_multiple(ct.c_int(nargs), ct.pointer(arrs)))
175+
return
173176

174177
def eval(*args):
175178
"""
176-
Evaluate the input
179+
Evaluate one or more inputs together
177180
178181
Parameters
179182
-----------
180183
args : arguments to be evaluated
184+
185+
Note
186+
-----
187+
188+
All the input arrays to this function should be of the same size.
189+
190+
Examples
191+
--------
192+
193+
>>> a = af.constant(1, 3, 3)
194+
>>> b = af.constant(2, 3, 3)
195+
>>> c = a + b
196+
>>> d = a - b
197+
>>> af.eval(c, d) # A single kernel is launched here
198+
>>> c
199+
arrayfire.Array()
200+
Type: float
201+
[3 3 1 1]
202+
3.0000 3.0000 3.0000
203+
3.0000 3.0000 3.0000
204+
3.0000 3.0000 3.0000
205+
206+
>>> d
207+
arrayfire.Array()
208+
Type: float
209+
[3 3 1 1]
210+
-1.0000 -1.0000 -1.0000
211+
-1.0000 -1.0000 -1.0000
212+
-1.0000 -1.0000 -1.0000
213+
"""
214+
for arg in args:
215+
if not isinstance(arg, Array):
216+
raise RuntimeError("All inputs to eval must be of type arrayfire.Array")
217+
218+
__eval(*args)
219+
220+
def set_manual_eval_flag(flag):
221+
"""
222+
Tells the backend JIT engine to disable heuristics for determining when to evaluate a JIT tree.
223+
224+
Parameters
225+
----------
226+
227+
flag : optional: bool.
228+
- Specifies if the heuristic evaluation of the JIT tree needs to be disabled.
229+
230+
Note
231+
----
232+
This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
181233
"""
234+
safe_call(backend.get().af_set_manual_eval_flag(flag))
182235

183-
__eval(args)
236+
def get_manual_eval_flag():
237+
"""
238+
Query the backend JIT engine to see if the user disabled heuristic evaluation of the JIT tree.
239+
240+
Note
241+
----
242+
This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
243+
"""
244+
res = ct.c_bool(False)
245+
safe_call(backend.get().af_get_manual_eval_flag(ct.pointer(res)))
246+
return res.value
184247

185248
def device_mem_info():
186249
"""
@@ -258,10 +321,27 @@ def lock_array(a):
258321
259322
Note
260323
-----
261-
- The device pointer of `a` is not freed by memory manager until `unlock_device_ptr()` is called.
324+
- The device pointer of `a` is not freed by memory manager until `unlock_array()` is called.
262325
"""
263326
safe_call(backend.get().af_lock_array(a.arr))
264327

328+
def is_locked_array(a):
329+
"""
330+
Check if the input array is locked by the user.
331+
332+
Parameters
333+
----------
334+
a: af.Array
335+
- A multi dimensional arrayfire array.
336+
337+
Returns
338+
-----------
339+
A bool specifying if the input array is locked.
340+
"""
341+
res = ct.c_bool(False)
342+
safe_call(backend.get().af_is_locked_array(ct.pointer(res), a.arr))
343+
return res.value
344+
265345
def unlock_device_ptr(a):
266346
"""
267347
This functions is deprecated. Please use unlock_array instead.

arrayfire/features.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
# The complete license agreement can be obtained at:
77
# http://arrayfire.com/licenses/BSD-3-Clause
88
########################################################
9+
910
"""
1011
Features class used for Computer Vision algorithms.
1112
"""
13+
1214
from .library import *
1315
from .array import *
1416
import numbers

0 commit comments

Comments
 (0)