Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 45 additions & 45 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,52 +1,52 @@
steps:
- group: ":test_tube: Tests"
steps:
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}"
matrix:
setup:
version:
- "1.10"
group:
- core
- neural_networks
- integration
runtime:
- "PJRT"
- "IFRT"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
- lib/ReactantCore/src
commands: |
touch LocalPreferences.toml
# steps:
# - group: ":test_tube: Tests"
# steps:
# - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}"
# matrix:
# setup:
# version:
# - "1.10"
# group:
# - core
# - neural_networks
# - integration
# runtime:
# - "PJRT"
# - "IFRT"
# plugins:
# - JuliaCI/julia#v1:
# version: "{{matrix.version}}"
# - JuliaCI/julia-coverage#v1:
# codecov: true
# dirs:
# - src
# - ext
# - lib/ReactantCore/src
# commands: |
# touch LocalPreferences.toml

echo "[Reactant]" >> LocalPreferences.toml
echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml
# echo "[Reactant]" >> LocalPreferences.toml
# echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml

cat LocalPreferences.toml
# cat LocalPreferences.toml

julia --project=. -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'
# julia --project=. -e 'println("--- :julia: Instantiating project")
# using Pkg
# Pkg.develop([PackageSpec(path="lib/ReactantCore")])'

julia --project=. -e 'println("--- :julia: Run Tests")
using Pkg
Pkg.test(; coverage="user")'
agents:
queue: "juliagpu"
cuda: "*"
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
JULIA_DEBUG: "Reactant,Reactant_jll"
CUDA_VISIBLE_DEVICES: 0
REACTANT_BACKEND_GROUP: "GPU"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 120
# julia --project=. -e 'println("--- :julia: Run Tests")
# using Pkg
# Pkg.test(; coverage="user")'
# agents:
# queue: "juliagpu"
# cuda: "*"
# env:
# REACTANT_TEST_GROUP: "{{matrix.group}}"
# JULIA_DEBUG: "Reactant,Reactant_jll"
# CUDA_VISIBLE_DEVICES: 0
# REACTANT_BACKEND_GROUP: "GPU"
# if: build.message !~ /\[skip tests\]/
# timeout_in_minutes: 120

# - label: ":julia: :linux: AMDGPU Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}"
# matrix:
Expand Down
64 changes: 32 additions & 32 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,32 @@ jobs:
fail-fast: false
matrix:
version:
- "1.10"
# - "1.10"
- "1.11"
# - 'nightly'
os:
- ubuntu-24.04
- ubuntu-latest
# `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`:
# <https://github.com/orgs/community/discussions/148648#discussioncomment-12099554>.
- ubuntu-22.04-arm
# - ubuntu-22.04-arm
# Disable `macOS-13` until
# <https://github.com/EnzymeAD/Reactant.jl/issues/867> is resolved.
# - macOS-13
- macOS-latest
- windows-latest
- linux-x86-ct6e-180-4tpu
# - macOS-latest
# - windows-latest
# - linux-x86-ct6e-180-4tpu
test_group:
- core
- neural_networks
- integration
# - neural_networks
# - integration
runtime:
- "pjrt"
# - "pjrt"
- "ifrt"
exclude:
- os: linux-x86-ct6e-180-4tpu
version: "1.10"
- os: linux-x86-ct6e-180-4tpu
runtime: "pjrt"
# exclude:
# - os: linux-x86-ct6e-180-4tpu
# version: "1.10"
# - os: linux-x86-ct6e-180-4tpu
# runtime: "pjrt"
uses: ./.github/workflows/CommonCI.yml
with:
julia_version: ${{ matrix.version }}
Expand All @@ -86,21 +86,21 @@ jobs:
# assertions: true
# test_group: ${{ matrix.test_group }}

downgrade:
strategy:
fail-fast: false
matrix:
test_group:
- core
- neural_networks
- integration
runtime:
- "pjrt"
- "ifrt"
uses: ./.github/workflows/CommonCI.yml
with:
julia_version: "1.10"
os: "ubuntu-24.04"
runtime: ${{ matrix.runtime }}
test_group: ${{ matrix.test_group }}
downgrade_testing: true
# downgrade:
# strategy:
# fail-fast: false
# matrix:
# test_group:
# - core
# - neural_networks
# - integration
# runtime:
# - "pjrt"
# - "ifrt"
# uses: ./.github/workflows/CommonCI.yml
# with:
# julia_version: "1.10"
# os: "ubuntu-24.04"
# runtime: ${{ matrix.runtime }}
# test_group: ${{ matrix.test_group }}
# downgrade_testing: true
6 changes: 5 additions & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3564,12 +3564,16 @@ end
$(size(input, dimension)) (got $(lhs))"
@assert 0 rhs size(input, dimension) "rhs must be between 0 and \
$(size(input, dimension)) (got $(rhs))"

sz = collect(Int64, size(input))
sz[dimension] = sz[dimension] + lhs + rhs

return TracedRArray{T,N}(
(),
MLIR.IR.result(
enzymexla.wrap(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), 1
),
size(input),
sz,
)
end

Expand Down
9 changes: 9 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,15 @@ end
@test fr!(vr) f!(v)
end

fn_test_wrap(x) = Reactant.Ops.wrap(x, 2, 1; dimension=3)

@testset "Ops.wrap" begin
x = Reactant.to_rarray(rand(2, 3, 4, 5))
out = @jit fn_test_wrap(x)

@test size(out) == (2, 3, 7, 5)
end

@testset "Ops.fill" begin
@testset "Fill with TracedScalar" begin
fn(x) = Ops.fill(x, [2, 3])
Expand Down
16 changes: 16 additions & 0 deletions test/optimize_comm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ function dus2(x, y)
return nothing
end

function wrap(x)
return Reactant.Ops.@opcall wrap(x, 2, 2; dimension=1)
end

if length(addressable_devices) ≥ 8
@testset "Rotate" begin
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)
Expand Down Expand Up @@ -108,4 +112,16 @@ if length(addressable_devices) ≥ 8
@test all(x .== convert(Array, rx))
@test all(y .== convert(Array, ry))
end

@testset "Wrap" begin
mesh = Sharding.Mesh(Reactant.devices(), (:x,))
sharding = Sharding.NamedSharding(mesh, (:x,))

x = Reactant.to_rarray(rand(192 * length(addressable_devices)); sharding)
@assert x isa ConcreteIFRTArray

@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")
end
end
130 changes: 65 additions & 65 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,72 @@ if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
end
end

@testset "Reactant.jl Tests" begin
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core"
if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal")
@safetestset "Metal Plugin" include("plugins/metal.jl")
end
# @testset "Reactant.jl Tests" begin
# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core"
# if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal")
# @safetestset "Metal Plugin" include("plugins/metal.jl")
# end

@safetestset "Layout" include("layout.jl")
@safetestset "Tracing" include("tracing.jl")
@safetestset "Basic" include("basic.jl")
@safetestset "Constructor" include("constructor.jl")
@safetestset "Autodiff" include("autodiff.jl")
@safetestset "Complex" include("complex.jl")
@safetestset "Broadcast" include("bcast.jl")
@safetestset "Struct" include("struct.jl")
@safetestset "Closure" include("closure.jl")
@safetestset "Compile" include("compile.jl")
@safetestset "IR" include("ir.jl")
@safetestset "Buffer Donation" include("buffer_donation.jl")
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Sorting" include("sorting.jl")
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
@safetestset "Indexing" include("indexing.jl")
@safetestset "Ranges" include("ranges.jl")
if !Sys.isapple()
@safetestset "Custom Number Types" include("custom_number_types.jl")
end
@safetestset "Sharding" include("sharding.jl")
@safetestset "Comm Optimization" include("optimize_comm.jl")
@safetestset "Cluster Detection" include("cluster_detector.jl")
@safetestset "Config" include("config.jl")
@safetestset "Batching" include("batching.jl")
@safetestset "QA" include("qa.jl")
end
# @safetestset "Layout" include("layout.jl")
# @safetestset "Tracing" include("tracing.jl")
# @safetestset "Basic" include("basic.jl")
# @safetestset "Constructor" include("constructor.jl")
# @safetestset "Autodiff" include("autodiff.jl")
# @safetestset "Complex" include("complex.jl")
# @safetestset "Broadcast" include("bcast.jl")
# @safetestset "Struct" include("struct.jl")
# @safetestset "Closure" include("closure.jl")
# @safetestset "Compile" include("compile.jl")
# @safetestset "IR" include("ir.jl")
# @safetestset "Buffer Donation" include("buffer_donation.jl")
# @safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
# @safetestset "Control Flow" include("control_flow.jl")
# @safetestset "Sorting" include("sorting.jl")
# @safetestset "Shortcuts to MLIR ops" include("ops.jl")
# @safetestset "Indexing" include("indexing.jl")
# @safetestset "Ranges" include("ranges.jl")
# if !Sys.isapple()
# @safetestset "Custom Number Types" include("custom_number_types.jl")
# end
# @safetestset "Sharding" include("sharding.jl")
@safetestset "Comm Optimization" include("optimize_comm.jl")
# @safetestset "Cluster Detection" include("cluster_detector.jl")
# @safetestset "Config" include("config.jl")
# @safetestset "Batching" include("batching.jl")
# @safetestset "QA" include("qa.jl")
# end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "CUDA" include("integration/cuda.jl")
@safetestset "KernelAbstractions" include("integration/kernelabstractions.jl")
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "OffsetArrays" include("integration/offsetarrays.jl")
@safetestset "OneHotArrays" include("integration/onehotarrays.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
@safetestset "Random" include("integration/random.jl")
@safetestset "Python" include("integration/python.jl")
@safetestset "Optimisers" include("integration/optimisers.jl")
@safetestset "FillArrays" include("integration/fillarrays.jl")
if ENZYMEJAX_INSTALLED[] && !Sys.isapple()
@safetestset "EnzymeJAX Export" include("integration/enzymejax.jl")
end
@safetestset "MPI" begin
using MPI
nranks = 2
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
end
# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
# @safetestset "CUDA" include("integration/cuda.jl")
# @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl")
# @safetestset "Linear Algebra" include("integration/linear_algebra.jl")
# @safetestset "OffsetArrays" include("integration/offsetarrays.jl")
# @safetestset "OneHotArrays" include("integration/onehotarrays.jl")
# @safetestset "AbstractFFTs" include("integration/fft.jl")
# @safetestset "SpecialFunctions" include("integration/special_functions.jl")
# @safetestset "Random" include("integration/random.jl")
# @safetestset "Python" include("integration/python.jl")
# @safetestset "Optimisers" include("integration/optimisers.jl")
# @safetestset "FillArrays" include("integration/fillarrays.jl")
# if ENZYMEJAX_INSTALLED[] && !Sys.isapple()
# @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl")
# end
# @safetestset "MPI" begin
# using MPI
# nranks = 2
# run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
# end

# Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580
if VERSION < v"1.12-"
@safetestset "Zygote" include("integration/zygote.jl")
end
end
# # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580
# if VERSION < v"1.12-"
# @safetestset "Zygote" include("integration/zygote.jl")
# end
# end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
@safetestset "Flux.jl Integration" include("nn/flux.jl")
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
@safetestset "Lux Integration" include("nn/lux.jl")
end
end
# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
# @safetestset "NNlib Primitives" include("nn/nnlib.jl")
# @safetestset "Flux.jl Integration" include("nn/flux.jl")
# @safetestset "LuxLib Primitives" include("nn/luxlib.jl")
# @safetestset "Lux Integration" include("nn/lux.jl")
# end
# end
Loading