Skip to content

Commit

Permalink
Add support for doing a masked fill from a tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
AngelEzquerra committed Nov 1, 2023
1 parent 12610a3 commit f01082d
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 15 deletions.
61 changes: 60 additions & 1 deletion src/arraymancer/tensor/selectors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ proc masked_fill*[T](t: var Tensor[T], mask: Tensor[bool], value: T) =
if maskElem:
tElem = value


proc masked_fill*[T](t: var Tensor[T], mask: openArray, value: T) =
## For the index of each element of t.
## Fill the elements at ``t[index]`` with the ``value``
Expand All @@ -179,6 +178,66 @@ proc masked_fill*[T](t: var Tensor[T], mask: openArray, value: T) =
return
t.masked_fill(mask.toTensor(), value)

proc masked_fill*[T](t: var Tensor[T], mask: Tensor[bool], value: Tensor[T]) =
## For the index of each element of t.
## Fill the elements at ``t[index]`` with the ``value``
## if their corresponding ``mask[index]`` is true.
## If not they are untouched.
##
## Example:
##
## t.masked_fill(t > 0, -1)
##
## or alternatively:
##
## t.masked_fill(t > 0): -1
if t.size == 0 or mask.size == 0:
return
check_elementwise(t, mask)

# Due to requiring unnecessary assigning `x` for a `false` mask
# apply2_inline is a bit slower for very sparse mask.
# As this is a critical operation, especially on dataframes, we use the lower level construct.
#
# t.apply2_inline(mask):
# if y:
# value
# else:
# x
omp_parallel_blocks(block_offset, block_size, t.size):
var n = block_offset
try:
for tElem, maskElem in mzip(t, mask, block_offset, block_size):
if maskElem:
tElem = value[n]
inc n
except IndexDefect:
raise newException(IndexDefect, "The size of the value tensor (" & $value.size &
") is smaller than the number of true elements in the mask (" & $mask.size & ")")

proc masked_fill*[T](t: var Tensor[T], mask: openArray, value: Tensor[T]) =
## For the index of each element of t.
## Fill the elements at ``t[index]`` with the ``value``
## if their corresponding ``mask[index]`` is true.
## If not they are untouched.
##
## Example:
##
## t.masked_fill(t > 0, -1)
##
## or alternatively:
##
## t.masked_fill(t > 0): -1
##
## The boolean mask must be
## - an array or sequence of bools
## - an array of arrays of bools,
## - ...
##
if t.size == 0 or mask.len == 0:
return
t.masked_fill(mask.toTensor(), value)

# Mask axis
# --------------------------------------------------------------------------------------------

Expand Down
36 changes: 22 additions & 14 deletions tests/tensor/test_selectors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -164,26 +164,34 @@ proc main() =
check: r == expected

test "Masked_fill":
block: # Numpy reference doc
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#boolean-array-indexing
# select non NaN
# x = np.array([[1., 2.], [np.nan, 3.], [np.nan, np.nan]])
# x[np.isnan(x)] = -1
# x
# np.array([[1., 2.], [-1, 3.], [-1, -1]])
var x = [[1.0, 2.0],
[NaN, 3.0],
[NaN, NaN]].toTensor
# Numpy reference doc
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#boolean-array-indexing
# select non NaN
# x = np.array([[1., 2.], [np.nan, 3.], [np.nan, np.nan]])
# x[np.isnan(x)] = -1
# x
# np.array([[1., 2.], [-1, 3.], [-1, -1]])
let t = [[1.0, 2.0],
[NaN, 3.0],
[NaN, NaN]].toTensor
block: # Single value masked fill
var x = t.clone()

x.masked_fill(x.isNaN, -1.0)

let expected = [[1.0, 2.0], [-1.0, 3.0], [-1.0, -1.0]].toTensor()
check: x == expected

block: # with regular arrays/sequences
var x = [[1.0, 2.0],
[NaN, 3.0],
[NaN, NaN]].toTensor
block: # Multiple value masked fill
var x = t.clone()

x.masked_fill(x.isNaN, [-10.0, -20.0, -30.0].toTensor())

let expected = [[1.0, 2.0], [-10.0, 3.0], [-20.0, -30.0]].toTensor()
check: x == expected

block: # Fill with regular arrays/sequences
var x = t.clone()

x.masked_fill(
[[false, false],
Expand Down

0 comments on commit f01082d

Please sign in to comment.