Skip to content

Added global precision option. Default to f64 #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrayfire/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _create_empty_array(numdims, idims, dtype):
numdims, c_pointer(c_dims), dtype.value))
return out_arr

def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=set_global_precision()):
"""
Internal function to create a C array. Should not be used externall.
"""
Expand Down
8 changes: 4 additions & 4 deletions arrayfire/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .util import _is_number
from .random import randu, randn, set_seed, get_seed

def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
def constant(val, d0, d1=None, d2=None, d3=None, dtype=set_global_precision()):
"""
Create a multi dimensional array whose elements contain the same value.

Expand Down Expand Up @@ -60,7 +60,7 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
# Store builtin range function to be used later
_brange = range

def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=Dtype.f32):
def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=set_global_precision()):
"""
Create a multi dimensional array using length of a dimension as range.

Expand Down Expand Up @@ -122,7 +122,7 @@ def range(d0, d1=None, d2=None, d3=None, dim=0, dtype=Dtype.f32):
return out


def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32):
def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=set_global_precision()):
"""
Create a multi dimensional array using the number of elements in the array as the range.

Expand Down Expand Up @@ -187,7 +187,7 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32)
4, c_pointer(tdims), dtype.value))
return out

def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32):
def identity(d0, d1, d2=None, d3=None, dtype=set_global_precision()):
"""
Create an identity matrix or batch of identity matrices.

Expand Down
8 changes: 4 additions & 4 deletions arrayfire/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def dilate(image, mask = None):

"""
if mask is None:
mask = constant(1, 3, 3, dtype=Dtype.f32)
mask = constant(1, 3, 3, dtype=set_global_precision())

output = Array()
safe_call(backend.get().af_dilate(c_pointer(output.arr), image.arr, mask.arr))
Expand Down Expand Up @@ -487,7 +487,7 @@ def dilate3(volume, mask = None):

"""
if mask is None:
mask = constant(1, 3, 3, 3, dtype=Dtype.f32)
mask = constant(1, 3, 3, 3, dtype=set_global_precision())

output = Array()
safe_call(backend.get().af_dilate3(c_pointer(output.arr), volume.arr, mask.arr))
Expand Down Expand Up @@ -516,7 +516,7 @@ def erode(image, mask = None):

"""
if mask is None:
mask = constant(1, 3, 3, dtype=Dtype.f32)
mask = constant(1, 3, 3, dtype=set_global_precision())

output = Array()
safe_call(backend.get().af_erode(c_pointer(output.arr), image.arr, mask.arr))
Expand Down Expand Up @@ -546,7 +546,7 @@ def erode3(volume, mask = None):
"""

if mask is None:
mask = constant(1, 3, 3, 3, dtype=Dtype.f32)
mask = constant(1, 3, 3, 3, dtype=set_global_precision())

output = Array()
safe_call(backend.get().af_erode3(c_pointer(output.arr), volume.arr, mask.arr))
Expand Down
16 changes: 16 additions & 0 deletions arrayfire/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,4 +702,20 @@ def get_size_of(dtype):
safe_call(backend.get().af_get_size_of(c_pointer(size), dtype.value))
return size.value

precision_setting = 'f64'

def set_global_precision(precision = precision_setting):

global global_precision

if(precision == 'f64'):
global_precision = Dtype.f64
elif(precision == 'f32'):
global_precision = Dtype.f32
else:
print("Unrecognized precision option. Defaulting to double precision")
global_precision = Dtype.f64

return global_precision

from .util import safe_call
4 changes: 2 additions & 2 deletions arrayfire/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_seed(self):
safe_call(backend.get().af_random_engine_get_seed(c_pointer(seed), self.engine))
return seed.value

def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None):
def randu(d0, d1=None, d2=None, d3=None, dtype=set_global_precision(), engine=None):
"""
Create a multi dimensional array containing values from a uniform distribution.

Expand Down Expand Up @@ -125,7 +125,7 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None):

return out

def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None):
def randn(d0, d1=None, d2=None, d3=None, dtype=set_global_precision(), engine=None):
"""
Create a multi dimensional array containing values from a normal distribution.

Expand Down