diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 064ef33235..9032452c2b 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,52 +1,52 @@ -steps: - - group: ":test_tube: Tests" - steps: - - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}" - matrix: - setup: - version: - - "1.10" - group: - - core - - neural_networks - - integration - runtime: - - "PJRT" - - "IFRT" - plugins: - - JuliaCI/julia#v1: - version: "{{matrix.version}}" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - - lib/ReactantCore/src - commands: | - touch LocalPreferences.toml +# steps: +# - group: ":test_tube: Tests" +# steps: +# - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}" +# matrix: +# setup: +# version: +# - "1.10" +# group: +# - core +# - neural_networks +# - integration +# runtime: +# - "PJRT" +# - "IFRT" +# plugins: +# - JuliaCI/julia#v1: +# version: "{{matrix.version}}" +# - JuliaCI/julia-coverage#v1: +# codecov: true +# dirs: +# - src +# - ext +# - lib/ReactantCore/src +# commands: | +# touch LocalPreferences.toml - echo "[Reactant]" >> LocalPreferences.toml - echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml +# echo "[Reactant]" >> LocalPreferences.toml +# echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml - cat LocalPreferences.toml +# cat LocalPreferences.toml - julia --project=. -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path="lib/ReactantCore")])' +# julia --project=. -e 'println("--- :julia: Instantiating project") +# using Pkg +# Pkg.develop([PackageSpec(path="lib/ReactantCore")])' - julia --project=. -e 'println("--- :julia: Run Tests") - using Pkg - Pkg.test(; coverage="user")' - agents: - queue: "juliagpu" - cuda: "*" - env: - REACTANT_TEST_GROUP: "{{matrix.group}}" - JULIA_DEBUG: "Reactant,Reactant_jll" - CUDA_VISIBLE_DEVICES: 0 - REACTANT_BACKEND_GROUP: "GPU" - if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 120 +# julia --project=. -e 'println("--- :julia: Run Tests") +# using Pkg +# Pkg.test(; coverage="user")' +# agents: +# queue: "juliagpu" +# cuda: "*" +# env: +# REACTANT_TEST_GROUP: "{{matrix.group}}" +# JULIA_DEBUG: "Reactant,Reactant_jll" +# CUDA_VISIBLE_DEVICES: 0 +# REACTANT_BACKEND_GROUP: "GPU" +# if: build.message !~ /\[skip tests\]/ +# timeout_in_minutes: 120 # - label: ":julia: :linux: AMDGPU Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}" # matrix: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b60f801cbc..f3aff51422 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,32 +34,32 @@ jobs: fail-fast: false matrix: version: - - "1.10" + # - "1.10" - "1.11" # - 'nightly' os: - - ubuntu-24.04 + - ubuntu-latest # `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`: # . - - ubuntu-22.04-arm + # - ubuntu-22.04-arm # Disable `macOS-13` until # is resolved. # - macOS-13 - - macOS-latest - - windows-latest - - linux-x86-ct6e-180-4tpu + # - macOS-latest + # - windows-latest + # - linux-x86-ct6e-180-4tpu test_group: - core - - neural_networks - - integration + # - neural_networks + # - integration runtime: - - "pjrt" + # - "pjrt" - "ifrt" - exclude: - - os: linux-x86-ct6e-180-4tpu - version: "1.10" - - os: linux-x86-ct6e-180-4tpu - runtime: "pjrt" + # exclude: + # - os: linux-x86-ct6e-180-4tpu + # version: "1.10" + # - os: linux-x86-ct6e-180-4tpu + # runtime: "pjrt" uses: ./.github/workflows/CommonCI.yml with: julia_version: ${{ matrix.version }} @@ -86,21 +86,21 @@ jobs: # assertions: true # test_group: ${{ matrix.test_group }} - downgrade: - strategy: - fail-fast: false - matrix: - test_group: - - core - - neural_networks - - integration - runtime: - - "pjrt" - - "ifrt" - uses: ./.github/workflows/CommonCI.yml - with: - julia_version: "1.10" - os: "ubuntu-24.04" - runtime: ${{ matrix.runtime }} - test_group: ${{ matrix.test_group }} - downgrade_testing: true + # downgrade: + # strategy: + # fail-fast: false + # matrix: + # test_group: + # - core + # - neural_networks + # - integration + # runtime: + # - "pjrt" + # - "ifrt" + # uses: ./.github/workflows/CommonCI.yml + # with: + # julia_version: "1.10" + # os: "ubuntu-24.04" + # runtime: ${{ matrix.runtime }} + # test_group: ${{ matrix.test_group }} + # downgrade_testing: true diff --git a/src/Ops.jl b/src/Ops.jl index ed162501ba..75e4c659cb 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3564,12 +3564,16 @@ end $(size(input, dimension)) (got $(lhs))" @assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \ $(size(input, dimension)) (got $(rhs))" + + sz = collect(Int64, size(input)) + sz[dimension] = sz[dimension] + lhs + rhs + return TracedRArray{T,N}( (), MLIR.IR.result( enzymexla.wrap(input.mlir_data; lhs, rhs, dimension=dimension - 1, location), 1 ), - size(input), + sz, ) end diff --git a/test/ops.jl b/test/ops.jl index eba91db228..c5edf06493 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -1138,6 +1138,15 @@ end @test fr!(vr) ≈ f!(v) end +fn_test_wrap(x) = Reactant.Ops.wrap(x, 2, 1; dimension=3) + +@testset "Ops.wrap" begin + x = Reactant.to_rarray(rand(2, 3, 4, 5)) + out = @jit fn_test_wrap(x) + + @test size(out) == (2, 3, 7, 5) +end + @testset "Ops.fill" begin @testset "Fill with TracedScalar" begin fn(x) = Ops.fill(x, [2, 3]) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 404ccaa3a0..dae7a3afa8 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -22,6 +22,10 @@ function dus2(x, y) return nothing end +function wrap(x) + return Reactant.Ops.@opcall wrap(x, 2, 2; dimension=1) +end + if length(addressable_devices) ≥ 8 @testset "Rotate" begin N = min((length(Reactant.devices()) ÷ 2) * 2, 8) @@ -108,4 +112,16 @@ if length(addressable_devices) ≥ 8 @test all(x .== convert(Array, rx)) @test all(y .== convert(Array, ry)) end + + @testset "Wrap" begin + mesh = Sharding.Mesh(Reactant.devices(), (:x,)) + sharding = Sharding.NamedSharding(mesh, (:x,)) + + x = Reactant.to_rarray(rand(192 * length(addressable_devices)); sharding) + @assert x isa ConcreteIFRTArray + + @test !contains(hlo, "all-to-all") + @test !contains(hlo, "all-gather") + @test contains(hlo, "collective-permute") + end end diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..20b1735d21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,72 +19,72 @@ if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" end end -@testset "Reactant.jl Tests" begin - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" - if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal") - @safetestset "Metal Plugin" include("plugins/metal.jl") - end +# @testset "Reactant.jl Tests" begin +# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" +# if Sys.isapple() && haskey(Reactant.XLA.global_backend_state.clients, "metal") +# @safetestset "Metal Plugin" include("plugins/metal.jl") +# end - @safetestset "Layout" include("layout.jl") - @safetestset "Tracing" include("tracing.jl") - @safetestset "Basic" include("basic.jl") - @safetestset "Constructor" include("constructor.jl") - @safetestset "Autodiff" include("autodiff.jl") - @safetestset "Complex" include("complex.jl") - @safetestset "Broadcast" include("bcast.jl") - @safetestset "Struct" include("struct.jl") - @safetestset "Closure" include("closure.jl") - @safetestset "Compile" include("compile.jl") - @safetestset "IR" include("ir.jl") - @safetestset "Buffer Donation" include("buffer_donation.jl") - @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") - @safetestset "Control Flow" include("control_flow.jl") - @safetestset "Sorting" include("sorting.jl") - @safetestset "Shortcuts to MLIR ops" include("ops.jl") - @safetestset "Indexing" include("indexing.jl") - @safetestset "Ranges" include("ranges.jl") - if !Sys.isapple() - @safetestset "Custom Number Types" include("custom_number_types.jl") - end - @safetestset "Sharding" include("sharding.jl") - @safetestset "Comm Optimization" include("optimize_comm.jl") - @safetestset "Cluster Detection" include("cluster_detector.jl") - @safetestset "Config" include("config.jl") - @safetestset "Batching" include("batching.jl") - @safetestset "QA" include("qa.jl") - end +# @safetestset "Layout" include("layout.jl") +# @safetestset "Tracing" include("tracing.jl") +# @safetestset "Basic" include("basic.jl") +# @safetestset "Constructor" include("constructor.jl") +# @safetestset "Autodiff" include("autodiff.jl") +# @safetestset "Complex" include("complex.jl") +# @safetestset "Broadcast" include("bcast.jl") +# @safetestset "Struct" include("struct.jl") +# @safetestset "Closure" include("closure.jl") +# @safetestset "Compile" include("compile.jl") +# @safetestset "IR" include("ir.jl") +# @safetestset "Buffer Donation" include("buffer_donation.jl") +# @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") +# @safetestset "Control Flow" include("control_flow.jl") +# @safetestset "Sorting" include("sorting.jl") +# @safetestset "Shortcuts to MLIR ops" include("ops.jl") +# @safetestset "Indexing" include("indexing.jl") +# @safetestset "Ranges" include("ranges.jl") +# if !Sys.isapple() +# @safetestset "Custom Number Types" include("custom_number_types.jl") +# end +# @safetestset "Sharding" include("sharding.jl") +@safetestset "Comm Optimization" include("optimize_comm.jl") +# @safetestset "Cluster Detection" include("cluster_detector.jl") +# @safetestset "Config" include("config.jl") +# @safetestset "Batching" include("batching.jl") +# @safetestset "QA" include("qa.jl") +# end - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" - @safetestset "CUDA" include("integration/cuda.jl") - @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") - @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "OffsetArrays" include("integration/offsetarrays.jl") - @safetestset "OneHotArrays" include("integration/onehotarrays.jl") - @safetestset "AbstractFFTs" include("integration/fft.jl") - @safetestset "SpecialFunctions" include("integration/special_functions.jl") - @safetestset "Random" include("integration/random.jl") - @safetestset "Python" include("integration/python.jl") - @safetestset "Optimisers" include("integration/optimisers.jl") - @safetestset "FillArrays" include("integration/fillarrays.jl") - if ENZYMEJAX_INSTALLED[] && !Sys.isapple() - @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") - end - @safetestset "MPI" begin - using MPI - nranks = 2 - run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) - end +# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" +# @safetestset "CUDA" include("integration/cuda.jl") +# @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") +# @safetestset "Linear Algebra" include("integration/linear_algebra.jl") +# @safetestset "OffsetArrays" include("integration/offsetarrays.jl") +# @safetestset "OneHotArrays" include("integration/onehotarrays.jl") +# @safetestset "AbstractFFTs" include("integration/fft.jl") +# @safetestset "SpecialFunctions" include("integration/special_functions.jl") +# @safetestset "Random" include("integration/random.jl") +# @safetestset "Python" include("integration/python.jl") +# @safetestset "Optimisers" include("integration/optimisers.jl") +# @safetestset "FillArrays" include("integration/fillarrays.jl") +# if ENZYMEJAX_INSTALLED[] && !Sys.isapple() +# @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") +# end +# @safetestset "MPI" begin +# using MPI +# nranks = 2 +# run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) +# end - # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580 - if VERSION < v"1.12-" - @safetestset "Zygote" include("integration/zygote.jl") - end - end +# # Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580 +# if VERSION < v"1.12-" +# @safetestset "Zygote" include("integration/zygote.jl") +# end +# end - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - @safetestset "NNlib Primitives" include("nn/nnlib.jl") - @safetestset "Flux.jl Integration" include("nn/flux.jl") - @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - @safetestset "Lux Integration" include("nn/lux.jl") - end -end +# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" +# @safetestset "NNlib Primitives" include("nn/nnlib.jl") +# @safetestset "Flux.jl Integration" include("nn/flux.jl") +# @safetestset "LuxLib Primitives" include("nn/luxlib.jl") +# @safetestset "Lux Integration" include("nn/lux.jl") +# end +# end