Skip to content

Commit

Permalink
Fix bugs in Python 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 11, 2024
1 parent 6fc6add commit 08b90cd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike],
'''
if isinstance(ary, Quantity):
return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)]
elif isinstance(ary, bst.typing.ArrayLike):
elif isinstance(ary, (jax.Array, np.ndarray)):
return jnp.array_split(ary, indices_or_sections, axis)
else:
raise ValueError(f'Unsupported type: {type(ary)} for array_split')
Expand Down
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_funcs_bit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def wrap_elementwise_bit_operation_binary(func):
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
elif isinstance(x, bst.typing.ArrayLike) and isinstance(y, bst.typing.ArrayLike):
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)):
return func(x, y, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}')
Expand Down

0 comments on commit 08b90cd

Please sign in to comment.