Skip to content

Commit 0e764de

Browse files
Fix offsetarrays support (#464)
* Fix offsetarrays support * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add test file --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ce20b3c commit 0e764de

File tree

6 files changed

+40
-4
lines changed

6 files changed

+40
-4
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2424
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
2525
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2626
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
27+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2728
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
2829
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2930
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3031
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3132
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
3233

33-
[sources.ReactantCore]
34-
path = "lib/ReactantCore"
35-
3634
[extensions]
3735
ReactantAbstractFFTsExt = "AbstractFFTs"
3836
ReactantArrayInterfaceExt = "ArrayInterface"
3937
ReactantCUDAExt = "CUDA"
4038
ReactantNNlibExt = "NNlib"
39+
ReactantOffsetArraysExt = "OffsetArrays"
4140
ReactantPythonCallExt = "PythonCall"
4241
ReactantRandom123Ext = "Random123"
4342
ReactantSpecialFunctionsExt = "SpecialFunctions"
@@ -75,3 +74,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
7574
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
7675
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
7776
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
77+
78+
[sources.ReactantCore]
79+
path = "lib/ReactantCore"

ext/ReactantOffsetArraysExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ReactantOffsetArraysExt
2+
3+
using OffsetArrays: OffsetArray
4+
using Reactant: Reactant, MLIR, Ops, TracedRArray
5+
6+
function Reactant.traced_type(
7+
::Type{<:OffsetArray{<:Any,N,T}}, seen::ST, ::Val{mode}, track_numbers
8+
) where {T,N,ST,mode}
9+
T2 = Reactant.traced_type(T, seen, Val(mode), track_numbers)
10+
return OffsetArray{eltype(T2),N,T2}
11+
end
12+
13+
end

src/Tracing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple)
705705
end
706706
@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x
707707
@inline function to_rarray_internal(
708-
@nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple
708+
@nospecialize(x::Array{<:ReactantPrimitive}), ::Tuple
709709
)
710710
return ConcreteRArray(x)
711711
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1515
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
1616
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1717
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
18+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1819
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
1920
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2021
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

test/integration/offsetarrays.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Reactant
2+
using Test
3+
using OffsetArrays
4+
5+
function scalar_index(x)
6+
@allowscalar getindex(x, -1, 0)
7+
end
8+
@testset "OffsetArrays" begin
9+
A = Float64.(reshape(1:15, 3, 5))
10+
OA = OffsetArray(A, -1:1, 0:4)
11+
rOA = Reactant.to_rarray(OA)
12+
13+
oval = scalar_index(OA)
14+
cval = scalar_index(rOA)
15+
@test cval oval
16+
17+
tval = @jit scalar_index(rOA)
18+
@test tval oval
19+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6262
# Temporarily disabled as minutia are debugged
6363
# @safetestset "CUDA" include("integration/cuda.jl")
6464
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
65+
@safetestset "OffsetArrays" include("integration/offsetarrays.jl")
6566
@safetestset "AbstractFFTs" include("integration/fft.jl")
6667
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
6768
@safetestset "Random" include("integration/random.jl")

0 commit comments

Comments
 (0)