Skip to content

Commit 52b60ad

Browse files
authored
Feature: allow colon indexing of traced **vectors** (#664)
* Feature: allow colon indexing of traced vectors * Style: fix space in mult op
1 parent 9a1179d commit 52b60ad

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

src/TracedRArray.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,6 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N}
268268
return a
269269
end
270270

271-
function Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N}
272-
v = TracedUtils.broadcast_to_size(v, size(a))
273-
set_mlir_data!(a, get_mlir_data(v))
274-
return a
275-
end
276-
277271
function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N}
278272
GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})")
279273
indices =
@@ -293,6 +287,16 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) whe
293287
end
294288

295289
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
290+
if (N == 1) && (indices isa Colon)
291+
# Remove ambiguity from the previous
292+
# ```julia
293+
# Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N}
294+
# ```
295+
# signature, which would be confused with this one for N=1.
296+
v = TracedUtils.broadcast_to_size(v, size(a))
297+
set_mlir_data!(a, get_mlir_data(v))
298+
return a
299+
end
296300
maybe_assert_scalar_setindexing(a, indices...)
297301

298302
indices = TracedUtils.normalize_indices(a, indices...)

test/indexing.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,23 @@ end
3131
# get_view_compiled = @compile get_view(x_concrete)
3232
end
3333

34+
function maskset!(y, x)
35+
y[:] = x
36+
return nothing
37+
end
38+
39+
@testset "setindex! with vectors & colon indexing" begin
40+
x = Reactant.to_rarray([4.0])
41+
y = Reactant.to_rarray([2.0])
42+
@jit(maskset!(y, x))
43+
@test y x
44+
45+
x = Reactant.to_rarray(ones(3))
46+
y = Reactant.to_rarray(2 * ones(3))
47+
@jit(maskset!(y, x))
48+
@test y x
49+
end
50+
3451
function masking(x)
3552
y = similar(x)
3653
y[1:2, :] .= 0

0 commit comments

Comments
 (0)