diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 23075cdb56..5ebff3931f 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "ddea75dacb5b9ba708e0b44a0161d4acf9adae31" +ENZYMEXLA_COMMIT = "04c15dc7c3736cca6f0a93c9934829fde489cc11" ENZYMEXLA_SHA256 = "" diff --git a/test/batching.jl b/test/batching.jl index 922739d5b2..fa0b577f5c 100644 --- a/test/batching.jl +++ b/test/batching.jl @@ -49,7 +49,11 @@ function run_auto_batching_tests(f::F, args...) where {F} ) @test occursin("stablehlo.while", hlo) - hlo = repr(@code_hlo f(args...)) + hlo = repr( + @code_hlo compile_options = CompileOptions(; + disable_auto_batching_passes=false + ) f(args...) + ) @test !occursin("stablehlo.while", hlo) end end @@ -115,9 +119,13 @@ end input1 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10)) input2 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10)) - hlo = @code_hlo optimize = false mctr(map_with_scalar_indexing, 1:8, input1, input2) + hlo = @code_hlo compile_options = CompileOptions(; disable_auto_batching_passes=true) mctr( + map_with_scalar_indexing, 1:8, input1, input2 + ) @test contains(repr(hlo), "stablehlo.while") - hlo = @code_hlo optimize = true mctr(map_with_scalar_indexing, 1:8, input1, input2) + hlo = @code_hlo compile_options = CompileOptions(; disable_auto_batching_passes=false) mctr( + map_with_scalar_indexing, 1:8, input1, input2 + ) @test !contains(repr(hlo), "stablehlo.while") res_ra = @jit mctr(map_with_scalar_indexing, 1:8, input1, input2)