Skip to content

Commit

Permalink
cast to float dtypes for nansum (ivy-llc#10512)
Browse files Browse the repository at this point in the history
  • Loading branch information
fnhirwa authored Feb 13, 2023
1 parent 84bedfb commit cb0d344
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,13 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False):
)
axis = draw(helpers.get_axis(shape=shape, force_int=True))
dtype1, values, dtype2 = draw(
helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0])
helpers.get_castable_dtype(
draw(helpers.get_dtypes("float")), dtype[0], values[0]
)
)
return [dtype1], [values], axis, dtype2


# nansum
@handle_test(
fn_tree="functional.ivy.experimental.nansum",
Expand Down

0 comments on commit cb0d344

Please sign in to comment.