Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 4, 2024
1 parent 5ccf35c commit 46be2e3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
8 changes: 4 additions & 4 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
# functions for checking
'check_dims',
'check_units',
'handle_units',
'assign_units',
'fail_for_dimension_mismatch',
'fail_for_unit_mismatch',
'assert_quantity',
Expand Down Expand Up @@ -4455,12 +4455,12 @@ def new_f(*args, **kwds):


@set_module_as('brainunit')
def handle_units(**au):
def assign_units(**au):
"""
Decorator to transform units of arguments passed to a function
"""

def do_handle_units(f):
def do_assign_units(f):
@wraps(f)
def new_f(*args, **kwds):
newkeyset = kwds.copy()
Expand Down Expand Up @@ -4530,7 +4530,7 @@ def new_f(*args, **kwds):

return new_f

return do_handle_units
return do_assign_units

def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
Expand Down
32 changes: 26 additions & 6 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,12 +1395,12 @@ def d_function2(true_result):
with pytest.raises(u.UnitMismatchError):
d_function2(2)

def test_handle_units(self):
def test_assign_units(self):
"""
Test the handle_units decorator
Test the assign_units decorator
"""

@u.handle_units(v=volt)
@u.assign_units(v=volt)
def a_function(v, x):
"""
v has to have units of volt, x can have any (or no) unit.
Expand All @@ -1421,7 +1421,7 @@ def a_function(v, x):
with pytest.raises(TypeError):
a_function(object(), None)

@u.handle_units(result=second)
@u.assign_units(result=second)
def b_function():
"""
Return a value in seconds if return_second is True, otherwise return
Expand All @@ -1432,7 +1432,7 @@ def b_function():
# Should work (returns second)
assert b_function() == 5 * second

@u.handle_units(a=bool, b=1, result=bool)
@u.assign_units(a=bool, b=1, result=bool)
def c_function(a, b):
if a:
return b > 0
Expand All @@ -1447,14 +1447,34 @@ def c_function(a, b):
c_function(1 * mV, 1)

# Multiple results
@u.handle_units(result=(second, volt))
@u.assign_units(result=(second, volt))
def d_function():
return 5, 3

# Should work (returns second)
assert d_function()[0] == 5 * second
assert d_function()[1] == 3 * volt

# Multiple results
@u.assign_units(result={'u': second, 'v': (volt, metre)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5, 'v': (3, 2)}
elif true_result == 1:
return 3, 5
else:
return 3, 5

# Should work (returns dict)
d_function2(0)
# Should fail (returns tuple)
with pytest.raises(TypeError):
d_function2(1)



def test_str_repr():
Expand Down

0 comments on commit 46be2e3

Please sign in to comment.