Skip to content

Commit

Permalink
refactor: Replaced lambda expressions with def functions (ivy-llc…
Browse files Browse the repository at this point in the history
  • Loading branch information
Sai-Suraj-27 authored Dec 25, 2023
1 parent 29ee162 commit 580db06
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 33 deletions.
5 changes: 4 additions & 1 deletion ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,10 @@ def dynamic_backend_as(value):
downcast_dtypes = False
upcast_dtypes = False
crosscast_dtypes = False
cast_dtypes = lambda: downcast_dtypes and upcast_dtypes and crosscast_dtypes


def cast_dtypes():
return downcast_dtypes and upcast_dtypes and crosscast_dtypes


def downcast_data_types(val=True):
Expand Down
10 changes: 7 additions & 3 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,13 @@ def try_array_function_override(func, overloaded_args, types, args, kwargs):

def _get_first_array(*args, **kwargs):
# ToDo: make this more efficient, with function ivy.nested_nth_index_where
array_fn = lambda x: (
ivy.is_array(x) if not hasattr(x, "_ivy_array") else ivy.is_array(x.ivy_array)
)
def array_fn(x):
return (
ivy.is_array(x)
if not hasattr(x, "_ivy_array")
else ivy.is_array(x.ivy_array)
)

array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"]
arr = None
if args:
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/frontends/jax/numpy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
if axis < 0:
axis = axis + ndim

func = lambda elem: func1d(elem, *args, **kwargs)
def func(elem):
return func1d(elem, *args, **kwargs)

for i in range(1, ndim - axis):
func = ivy.vmap(func, in_axes=i, out_axes=-1)
for i in range(axis):
Expand Down
33 changes: 24 additions & 9 deletions ivy/functional/frontends/numpy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def _assert_array(args, dtype, scalar_check=False, casting="safe"):
if ivy.is_bool_dtype(dtype):
assert_fn = ivy.is_bool_dtype
if ivy.is_int_dtype(dtype):
assert_fn = lambda x: not ivy.is_float_dtype(x)

def assert_fn(x): # noqa F811
return not ivy.is_float_dtype(x)

if assert_fn:
ivy.utils.assertions.check_all_or_any_fn(
*args,
Expand All @@ -51,13 +54,19 @@ def _assert_no_array(args, dtype, scalar_check=False, none=False):
if args:
first_arg = args[0]
fn_func = ivy.as_ivy_dtype(dtype) if ivy.exists(dtype) else ivy.dtype(first_arg)
assert_fn = lambda x: ivy.dtype(x) == fn_func

def assert_fn(x):
return ivy.dtype(x) == fn_func

if scalar_check:
assert_fn = lambda x: (
ivy.dtype(x) == fn_func
if ivy.shape(x) != ()
else _casting_no_special_case(ivy.dtype(x), fn_func, none)
)

def assert_fn(x): # noqa F811
return (
ivy.dtype(x) == fn_func
if ivy.shape(x) != ()
else _casting_no_special_case(ivy.dtype(x), fn_func, none)
)

ivy.utils.assertions.check_all_or_any_fn(
*args,
fn=assert_fn,
Expand Down Expand Up @@ -105,9 +114,15 @@ def _assert_scalar(args, dtype):
if args and dtype:
assert_fn = None
if ivy.is_int_dtype(dtype):
assert_fn = lambda x: not isinstance(x, float)

def assert_fn(x): # noqa F811
return not isinstance(x, float)

elif ivy.is_bool_dtype(dtype):
assert_fn = lambda x: isinstance(x, bool)

def assert_fn(x):
return isinstance(x, bool)

if assert_fn:
ivy.utils.assertions.check_all_or_any_fn(
*args,
Expand Down
15 changes: 12 additions & 3 deletions ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,11 +1689,20 @@ def area_interpolate(x, dims, size, scale):
def get_interpolate_kernel(mode):
kernel_func = _triangle_kernel
if mode == "tf_bicubic":
kernel_func = lambda inputs: _cubic_kernel(inputs)

def kernel_func(inputs): # noqa F811
return _cubic_kernel(inputs)

elif mode == "lanczos3":
kernel_func = lambda inputs: _lanczos_kernel(3, inputs)

def kernel_func(inputs):
return _lanczos_kernel(3, inputs)

elif mode == "lanczos5":
kernel_func = lambda inputs: _lanczos_kernel(5, inputs)

def kernel_func(inputs):
return _lanczos_kernel(5, inputs)

return kernel_func


Expand Down
53 changes: 37 additions & 16 deletions ivy/utils/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@ def _broadcast_inputs(x1, x2):


def check_less(x1, x2, allow_equal=False, message="", as_array=True):
comp_fn = lambda x1, x2: (ivy.any(x1 > x2), ivy.any(x1 >= x2))
def comp_fn(x1, x2):
return ivy.any(x1 > x2), ivy.any(x1 >= x2)

if not as_array:
iter_comp_fn = lambda x1_, x2_: (
any(x1 > x2 for x1, x2 in zip(x1_, x2_)),
any(x1 >= x2 for x1, x2 in zip(x1_, x2_)),
)
comp_fn = lambda x1, x2: iter_comp_fn(*_broadcast_inputs(x1, x2))

def iter_comp_fn(x1_, x2_):
return any(x1 > x2 for x1, x2 in zip(x1_, x2_)), any(
x1 >= x2 for x1, x2 in zip(x1_, x2_)
)

def comp_fn(x1, x2): # noqa F811
return iter_comp_fn(*_broadcast_inputs(x1, x2))

gt, gt_eq = comp_fn(x1, x2)
# less_equal
if allow_equal and gt:
Expand All @@ -42,13 +48,19 @@ def check_less(x1, x2, allow_equal=False, message="", as_array=True):


def check_greater(x1, x2, allow_equal=False, message="", as_array=True):
comp_fn = lambda x1, x2: (ivy.any(x1 < x2), ivy.any(x1 <= x2))
def comp_fn(x1, x2):
return ivy.any(x1 < x2), ivy.any(x1 <= x2)

if not as_array:
iter_comp_fn = lambda x1_, x2_: (
any(x1 < x2 for x1, x2 in zip(x1_, x2_)),
any(x1 <= x2 for x1, x2 in zip(x1_, x2_)),
)
comp_fn = lambda x1, x2: iter_comp_fn(*_broadcast_inputs(x1, x2))

def iter_comp_fn(x1_, x2_):
return any(x1 < x2 for x1, x2 in zip(x1_, x2_)), any(
x1 <= x2 for x1, x2 in zip(x1_, x2_)
)

def comp_fn(x1, x2): # noqa F811
return iter_comp_fn(*_broadcast_inputs(x1, x2))

lt, lt_eq = comp_fn(x1, x2)
# greater_equal
if allow_equal and lt:
Expand All @@ -63,11 +75,20 @@ def check_greater(x1, x2, allow_equal=False, message="", as_array=True):

def check_equal(x1, x2, inverse=False, message="", as_array=True):
# not_equal
eq_fn = lambda x1, x2: (x1 == x2 if inverse else x1 != x2)
comp_fn = lambda x1, x2: ivy.any(eq_fn(x1, x2))
def eq_fn(x1, x2):
return x1 == x2 if inverse else x1 != x2

def comp_fn(x1, x2):
return ivy.any(eq_fn(x1, x2))

if not as_array:
iter_comp_fn = lambda x1_, x2_: any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_))
comp_fn = lambda x1, x2: iter_comp_fn(*_broadcast_inputs(x1, x2))

def iter_comp_fn(x1_, x2_):
return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_))

def comp_fn(x1, x2): # noqa F811
return iter_comp_fn(*_broadcast_inputs(x1, x2))

eq = comp_fn(x1, x2)
if inverse and eq:
raise ivy.utils.exceptions.IvyException(
Expand Down

0 comments on commit 580db06

Please sign in to comment.