diff --git a/arrayfire/array.py b/arrayfire/array.py index beb7a0076..7b9fcff16 100644 --- a/arrayfire/array.py +++ b/arrayfire/array.py @@ -122,7 +122,7 @@ def _create_empty_array(numdims, idims, dtype): numdims, c_pointer(c_dims), dtype.value)) return out_arr -def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): +def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=set_global_precision()): """ Internal function to create a C array. Should not be used externall. """ diff --git a/arrayfire/data.py b/arrayfire/data.py index 73f516073..f87142478 100644 --- a/arrayfire/data.py +++ b/arrayfire/data.py @@ -18,7 +18,7 @@ from .util import _is_number from .random import randu, randn, set_seed, get_seed -def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): +def constant(val, d0, d1=None, d2=None, d3=None, dtype=set_global_precision()): """ Create a multi dimensional array whose elements contain the same value. @@ -60,7 +60,7 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): # Store builtin range function to be used later _brange = range -def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=Dtype.f32): +def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=set_global_precision()): """ Create a multi dimensional array using length of a dimension as range. @@ -122,7 +122,7 @@ def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=Dtype.f32): return out -def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32): +def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=set_global_precision()): """ Create a multi dimensional array using the number of elements in the array as the range. @@ -187,7 +187,7 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32) 4, c_pointer(tdims), dtype.value)) return out -def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32): +def identity(d0, d1, d2=None, d3=None, dtype=set_global_precision()): """ Create an identity matrix or batch of identity matrices. diff --git a/arrayfire/image.py b/arrayfire/image.py index 5fadd991a..58a9fc0eb 100644 --- a/arrayfire/image.py +++ b/arrayfire/image.py @@ -458,7 +458,7 @@ def dilate(image, mask = None): """ if mask is None: - mask = constant(1, 3, 3, dtype=Dtype.f32) + mask = constant(1, 3, 3, dtype=set_global_precision()) output = Array() safe_call(backend.get().af_dilate(c_pointer(output.arr), image.arr, mask.arr)) @@ -487,7 +487,7 @@ def dilate3(volume, mask = None): """ if mask is None: - mask = constant(1, 3, 3, 3, dtype=Dtype.f32) + mask = constant(1, 3, 3, 3, dtype=set_global_precision()) output = Array() safe_call(backend.get().af_dilate3(c_pointer(output.arr), volume.arr, mask.arr)) @@ -516,7 +516,7 @@ def erode(image, mask = None): """ if mask is None: - mask = constant(1, 3, 3, dtype=Dtype.f32) + mask = constant(1, 3, 3, dtype=set_global_precision()) output = Array() safe_call(backend.get().af_erode(c_pointer(output.arr), image.arr, mask.arr)) @@ -546,7 +546,7 @@ def erode3(volume, mask = None): """ if mask is None: - mask = constant(1, 3, 3, 3, dtype=Dtype.f32) + mask = constant(1, 3, 3, 3, dtype=set_global_precision()) output = Array() safe_call(backend.get().af_erode3(c_pointer(output.arr), volume.arr, mask.arr)) diff --git a/arrayfire/library.py b/arrayfire/library.py index 970f7870f..ed36aa43a 100644 --- a/arrayfire/library.py +++ b/arrayfire/library.py @@ -702,4 +702,20 @@ def get_size_of(dtype): safe_call(backend.get().af_get_size_of(c_pointer(size), dtype.value)) return size.value +precision_setting = 'f64' + +def set_global_precision(precision = precision_setting): + + global global_precision + + if(precision == 'f64'): + global_precision = Dtype.f64 + elif(precision == 'f32'): + global_precision = Dtype.f32 + else: + print("Unrecognized precision option. Defaulting to double precision") + global_precision = Dtype.f64 + + return global_precision + from .util import safe_call diff --git a/arrayfire/random.py b/arrayfire/random.py index 179833db8..7687b1799 100644 --- a/arrayfire/random.py +++ b/arrayfire/random.py @@ -81,7 +81,7 @@ def get_seed(self): safe_call(backend.get().af_random_engine_get_seed(c_pointer(seed), self.engine)) return seed.value -def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None): +def randu(d0, d1=None, d2=None, d3=None, dtype=set_global_precision(), engine=None): """ Create a multi dimensional array containing values from a uniform distribution. @@ -125,7 +125,7 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None): return out -def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None): +def randn(d0, d1=None, d2=None, d3=None, dtype=set_global_precision(), engine=None): """ Create a multi dimensional array containing values from a normal distribution.