11
11
Array class and helper functions.
12
12
"""
13
13
14
- from .algorithm import sum , count
15
- from .arith import cast
16
14
import inspect
17
15
import os
18
16
from .library import *
27
25
28
26
_display_dims_limit = None
29
27
30
-
31
28
def set_display_dims_limit (* dims ):
32
29
"""
33
30
Sets the dimension limit after which array's data won't get
@@ -47,7 +44,6 @@ def set_display_dims_limit(*dims):
47
44
global _display_dims_limit
48
45
_display_dims_limit = dims
49
46
50
-
51
47
def get_display_dims_limit ():
52
48
"""
53
49
Gets the dimension limit after which array's data won't get
@@ -71,7 +67,6 @@ def get_display_dims_limit():
71
67
"""
72
68
return _display_dims_limit
73
69
74
-
75
70
def _in_display_dims_limit (dims ):
76
71
if _is_running_in_py_charm :
77
72
return False
@@ -85,7 +80,6 @@ def _in_display_dims_limit(dims):
85
80
return False
86
81
return True
87
82
88
-
89
83
def _create_array (buf , numdims , idims , dtype , is_device ):
90
84
out_arr = c_void_ptr_t (0 )
91
85
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -97,7 +91,6 @@ def _create_array(buf, numdims, idims, dtype, is_device):
97
91
numdims , c_pointer (c_dims ), dtype .value ))
98
92
return out_arr
99
93
100
-
101
94
def _create_strided_array (buf , numdims , idims , dtype , is_device , offset , strides ):
102
95
out_arr = c_void_ptr_t (0 )
103
96
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -119,15 +112,16 @@ def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides
119
112
location .value ))
120
113
return out_arr
121
114
122
-
123
115
def _create_empty_array (numdims , idims , dtype ):
124
116
out_arr = c_void_ptr_t (0 )
117
+
118
+ if numdims == 0 : return out_arr
119
+
125
120
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
126
121
safe_call (backend .get ().af_create_handle (c_pointer (out_arr ),
127
122
numdims , c_pointer (c_dims ), dtype .value ))
128
123
return out_arr
129
124
130
-
131
125
def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = Dtype .f32 ):
132
126
"""
133
127
Internal function to create a C array. Should not be used externall.
@@ -182,7 +176,6 @@ def _binary_func(lhs, rhs, c_func):
182
176
183
177
return out
184
178
185
-
186
179
def _binary_funcr (lhs , rhs , c_func ):
187
180
out = Array ()
188
181
other = lhs
@@ -199,10 +192,9 @@ def _binary_funcr(lhs, rhs, c_func):
199
192
200
193
return out
201
194
202
-
203
195
def _ctype_to_lists (ctype_arr , dim , shape , offset = 0 ):
204
196
if (dim == 0 ):
205
- return list (ctype_arr [offset : offset + shape [0 ]])
197
+ return list (ctype_arr [offset : offset + shape [0 ]])
206
198
else :
207
199
dim_len = shape [dim ]
208
200
res = [[]] * dim_len
@@ -211,7 +203,6 @@ def _ctype_to_lists(ctype_arr, dim, shape, offset=0):
211
203
offset += shape [0 ]
212
204
return res
213
205
214
-
215
206
def _slice_to_length (key , dim ):
216
207
tkey = [key .start , key .stop , key .step ]
217
208
@@ -230,7 +221,6 @@ def _slice_to_length(key, dim):
230
221
231
222
return int (((tkey [1 ] - tkey [0 ] - 1 ) / tkey [2 ]) + 1 )
232
223
233
-
234
224
def _get_info (dims , buf_len ):
235
225
elements = 1
236
226
numdims = 0
@@ -260,7 +250,6 @@ def _get_indices(key):
260
250
261
251
return inds
262
252
263
-
264
253
def _get_assign_dims (key , idims ):
265
254
266
255
dims = [1 ]* 4
@@ -307,7 +296,6 @@ def _get_assign_dims(key, idims):
307
296
else :
308
297
raise IndexError ("Invalid type while assigning to arrayfire.array" )
309
298
310
-
311
299
def transpose (a , conj = False ):
312
300
"""
313
301
Perform the transpose on an input.
@@ -330,7 +318,6 @@ def transpose(a, conj=False):
330
318
safe_call (backend .get ().af_transpose (c_pointer (out .arr ), a .arr , conj ))
331
319
return out
332
320
333
-
334
321
def transpose_inplace (a , conj = False ):
335
322
"""
336
323
Perform inplace transpose on an input.
@@ -351,7 +338,6 @@ def transpose_inplace(a, conj=False):
351
338
"""
352
339
safe_call (backend .get ().af_transpose_inplace (a .arr , conj ))
353
340
354
-
355
341
class Array (BaseArray ):
356
342
357
343
"""
@@ -461,8 +447,8 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
461
447
462
448
super (Array , self ).__init__ ()
463
449
464
- buf = None
465
- buf_len = 0
450
+ buf = None
451
+ buf_len = 0
466
452
467
453
if dtype is not None :
468
454
if isinstance (dtype , str ):
@@ -472,7 +458,7 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
472
458
else :
473
459
type_char = None
474
460
475
- _type_char = 'f'
461
+ _type_char = 'f'
476
462
477
463
if src is not None :
478
464
@@ -483,12 +469,12 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
483
469
host = __import__ ("array" )
484
470
485
471
if isinstance (src , host .array ):
486
- buf , buf_len = src .buffer_info ()
472
+ buf ,buf_len = src .buffer_info ()
487
473
_type_char = src .typecode
488
474
numdims , idims = _get_info (dims , buf_len )
489
475
elif isinstance (src , list ):
490
476
tmp = host .array ('f' , src )
491
- buf , buf_len = tmp .buffer_info ()
477
+ buf ,buf_len = tmp .buffer_info ()
492
478
_type_char = tmp .typecode
493
479
numdims , idims = _get_info (dims , buf_len )
494
480
elif isinstance (src , int ) or isinstance (src , c_void_ptr_t ):
@@ -512,7 +498,7 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
512
498
raise TypeError ("src is an object of unsupported class" )
513
499
514
500
if (type_char is not None and
515
- type_char != _type_char ):
501
+ type_char != _type_char ):
516
502
raise TypeError ("Can not create array of requested type from input data type" )
517
503
if (offset is None and strides is None ):
518
504
self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
@@ -634,8 +620,8 @@ def strides(self):
634
620
s2 = c_dim_t (0 )
635
621
s3 = c_dim_t (0 )
636
622
safe_call (backend .get ().af_get_strides (c_pointer (s0 ), c_pointer (s1 ),
637
- c_pointer (s2 ), c_pointer (s3 ), self .arr ))
638
- strides = (s0 .value , s1 .value , s2 .value , s3 .value )
623
+ c_pointer (s2 ), c_pointer (s3 ), self .arr ))
624
+ strides = (s0 .value ,s1 .value ,s2 .value ,s3 .value )
639
625
return strides [:self .numdims ()]
640
626
641
627
def elements (self ):
@@ -694,8 +680,8 @@ def dims(self):
694
680
d2 = c_dim_t (0 )
695
681
d3 = c_dim_t (0 )
696
682
safe_call (backend .get ().af_get_dims (c_pointer (d0 ), c_pointer (d1 ),
697
- c_pointer (d2 ), c_pointer (d3 ), self .arr ))
698
- dims = (d0 .value , d1 .value , d2 .value , d3 .value )
683
+ c_pointer (d2 ), c_pointer (d3 ), self .arr ))
684
+ dims = (d0 .value ,d1 .value ,d2 .value ,d3 .value )
699
685
return dims [:self .numdims ()]
700
686
701
687
@property
@@ -920,7 +906,7 @@ def __itruediv__(self, other):
920
906
"""
921
907
Perform self /= other.
922
908
"""
923
- self = _binary_func (self , other , backend .get ().af_div )
909
+ self = _binary_func (self , other , backend .get ().af_div )
924
910
return self
925
911
926
912
def __rtruediv__ (self , other ):
@@ -939,7 +925,7 @@ def __idiv__(self, other):
939
925
"""
940
926
Perform other / self.
941
927
"""
942
- self = _binary_func (self , other , backend .get ().af_div )
928
+ self = _binary_func (self , other , backend .get ().af_div )
943
929
return self
944
930
945
931
def __rdiv__ (self , other ):
@@ -958,7 +944,7 @@ def __imod__(self, other):
958
944
"""
959
945
Perform self %= other.
960
946
"""
961
- self = _binary_func (self , other , backend .get ().af_mod )
947
+ self = _binary_func (self , other , backend .get ().af_mod )
962
948
return self
963
949
964
950
def __rmod__ (self , other ):
@@ -977,7 +963,7 @@ def __ipow__(self, other):
977
963
"""
978
964
Perform self **= other.
979
965
"""
980
- self = _binary_func (self , other , backend .get ().af_pow )
966
+ self = _binary_func (self , other , backend .get ().af_pow )
981
967
return self
982
968
983
969
def __rpow__ (self , other ):
@@ -1120,15 +1106,15 @@ def logical_and(self, other):
1120
1106
Return self && other.
1121
1107
"""
1122
1108
out = Array ()
1123
- safe_call (backend .get ().af_and (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1109
+ safe_call (backend .get ().af_and (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1124
1110
return out
1125
1111
1126
1112
def logical_or (self , other ):
1127
1113
"""
1128
1114
Return self || other.
1129
1115
"""
1130
1116
out = Array ()
1131
- safe_call (backend .get ().af_or (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1117
+ safe_call (backend .get ().af_or (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1132
1118
return out
1133
1119
1134
1120
def __nonzero__ (self ):
@@ -1158,11 +1144,12 @@ def __getitem__(self, key):
1158
1144
inds = _get_indices (key )
1159
1145
1160
1146
safe_call (backend .get ().af_index_gen (c_pointer (out .arr ),
1161
- self .arr , c_dim_t (n_dims ), inds .pointer ))
1147
+ self .arr , c_dim_t (n_dims ), inds .pointer ))
1162
1148
return out
1163
1149
except RuntimeError as e :
1164
1150
raise IndexError (str (e ))
1165
1151
1152
+
1166
1153
def __setitem__ (self , key , val ):
1167
1154
"""
1168
1155
Perform self[key] = val
@@ -1188,14 +1175,14 @@ def __setitem__(self, key, val):
1188
1175
n_dims = 1
1189
1176
other_arr = constant_array (val , int (num ), dtype = self .type ())
1190
1177
else :
1191
- other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1178
+ other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1192
1179
del_other = True
1193
1180
else :
1194
1181
other_arr = val .arr
1195
1182
del_other = False
1196
1183
1197
1184
out_arr = c_void_ptr_t (0 )
1198
- inds = _get_indices (key )
1185
+ inds = _get_indices (key )
1199
1186
1200
1187
safe_call (backend .get ().af_assign_gen (c_pointer (out_arr ),
1201
1188
self .arr , c_dim_t (n_dims ), inds .pointer ,
@@ -1414,7 +1401,6 @@ def to_ndarray(self, output=None):
1414
1401
safe_call (backend .get ().af_get_data_ptr (c_void_ptr_t (output .ctypes .data ), tmp .arr ))
1415
1402
return output
1416
1403
1417
-
1418
1404
def display (a , precision = 4 ):
1419
1405
"""
1420
1406
Displays the contents of an array.
@@ -1440,7 +1426,6 @@ def display(a, precision=4):
1440
1426
safe_call (backend .get ().af_print_array_gen (name .encode ('utf-8' ),
1441
1427
a .arr , c_int_t (precision )))
1442
1428
1443
-
1444
1429
def save_array (key , a , filename , append = False ):
1445
1430
"""
1446
1431
Save an array to disk.
@@ -1472,7 +1457,6 @@ def save_array(key, a, filename, append=False):
1472
1457
append ))
1473
1458
return index .value
1474
1459
1475
-
1476
1460
def read_array (filename , index = None , key = None ):
1477
1461
"""
1478
1462
Read an array from disk.
@@ -1506,3 +1490,6 @@ def read_array(filename, index=None, key=None):
1506
1490
key .encode ('utf-8' )))
1507
1491
1508
1492
return out
1493
+
1494
+ from .algorithm import (sum , count )
1495
+ from .arith import cast
0 commit comments