Skip to content

Commit

Permalink
add index_fill taking sequence / array / tensor
Browse files Browse the repository at this point in the history
To allow

```nim
t[[0, 2]] = [1, 2] # optionally seq or tensor
t[@[0, 2]] = [1, 2]
t[toTensor [0, 2]] = [1, 2]
```

filling multiple indices at same time from a array/seq/tensor.

Previously this only supported to set all indices given to the same value.
  • Loading branch information
Vindaar committed May 28, 2024
1 parent 3211c4e commit 785cb7e
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion src/arraymancer/tensor/selectors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,40 @@ proc index_select*[T; Idx: byte or char or SomeInteger](t: Tensor[T], axis: int,
var t_slice = t.atAxisIndex(axis, int(index))
r_slice.copyFrom(t_slice)

proc index_fill*[T; Idx: byte or char or SomeInteger](t: var Tensor[T], axis: int, indices: Tensor[Idx], value: T) =
template index_fill_vector_body(): untyped {.dirty.} =
if t.len == 0 or indices.len == 0:
return
if indices.len != values.len:
raise newException(ValueError, "Cannot assign values to indices, because numbers mismatch: " &
"# indices = " & $indices.len & ", # values = " & $values.len)
when typeof(indices) isnot Tensor:
template enumerate(arg): untyped {.gensym.} = pairs(arg)
for i, index in enumerate(indices):
var t_slice = t.atAxisIndex(axis, int(index))
for old_val in t_slice.mitems():
old_val = values[i]

# These are a bit terrible, but trying to overload `openArray[Idx] | Tensor[Idx]` for example doesn't seem to work
proc index_fill*[T; Idx: byte or char or SomeInteger](t: var Tensor[T], axis: int, indices: openArray[Idx], values: openArray[T]) =
## Replace elements of `t` indicated by their `indices` along `axis` with `value`
## This is equivalent to Numpy `put`.
index_fill_vector_body()

proc index_fill*[T; Idx: byte or char or SomeInteger](t: var Tensor[T], axis: int, indices: Tensor[Idx], values: openArray[T]) =
## Replace elements of `t` indicated by their `indices` along `axis` with `value`
## This is equivalent to Numpy `put`.
index_fill_vector_body()

proc index_fill*[T; Idx: byte or char or SomeInteger](t: var Tensor[T], axis: int, indices: openArray[Idx], values: Tensor[T]) =
## Replace elements of `t` indicated by their `indices` along `axis` with `value`
## This is equivalent to Numpy `put`.
index_fill_vector_body()

proc index_fill*[T; Idx: byte or char or SomeInteger](t: var Tensor[T], axis: int, indices: Tensor[Idx], values: Tensor[T]) =
## Replace elements of `t` indicated by their `indices` along `axis` with `value`
## This is equivalent to Numpy `put`.
index_fill_vector_body()

template index_fill_scalar_body(): untyped {.dirty.} =
if t.size == 0 or indices.size == 0:
return
Expand Down

0 comments on commit 785cb7e

Please sign in to comment.