926926
927927 ra = Reactant. to_rarray (x)
928928 @jit dip! (ra)
929- ra[:a ] ≈ (2.7 * 2 ) * ones (4 )
929+ @test ra[:a ] ≈ (2.7 * 3.1 ) * ones (4 )
930930end
931931
932932@testset " @code_xla" begin
@@ -1429,7 +1429,10 @@ end
14291429end
14301430
14311431zip_iterator (a, b) = mapreduce (splat (* ), + , zip (a, b))
1432+ zip_iterator2 (a, b) = mapreduce (splat (.- ), + , zip (a, b))
14321433enumerate_iterator (a) = mapreduce (splat (* ), + , enumerate (a))
1434+ enumerate_iterator2 (a) = mapreduce (splat (.- ), + , enumerate (a))
1435+ mapreduce_vector (a) = mapreduce (- , + , a)
14331436
14341437function nested_mapreduce_zip (x, y)
14351438 return mapreduce (+ , zip (eachcol (x), eachcol (y)); init= 0.0f0 ) do (x, y)
@@ -1448,18 +1451,30 @@ end
14481451@testset " Base.Iterators" begin
14491452 @testset " zip" begin
14501453 N = 10
1451- a = range (1.0 , 5.0 ; length= N)
1452- x = range (10.0 , 15.0 ; length= N + 2 )
1454+ a = collect ( range (1.0 , 5.0 ; length= N) )
1455+ x = collect ( range (10.0 , 15.0 ; length= N + 2 ) )
14531456 x_ra = Reactant. to_rarray (x)
14541457
14551458 @test @jit (zip_iterator (a, x_ra)) ≈ zip_iterator (a, x)
1459+
1460+ a = [rand (Float32, 2 , 3 ) for _ in 1 : 10 ]
1461+ x = [rand (Float32, 2 , 3 ) for _ in 1 : 10 ]
1462+ a_ra = Reactant. to_rarray (a)
1463+ x_ra = Reactant. to_rarray (x)
1464+
1465+ @test @jit (zip_iterator2 (a_ra, x_ra)) ≈ zip_iterator2 (a, x)
14561466 end
14571467
14581468 @testset " enumerate" begin
1459- x = range (1.0 , 5.0 ; length= 10 )
1469+ x = collect ( range (1.0 , 5.0 ; length= 10 ) )
14601470 x_ra = Reactant. to_rarray (x)
14611471
14621472 @test @jit (enumerate_iterator (x_ra)) ≈ enumerate_iterator (x)
1473+
1474+ x = [rand (Float32, 2 , 3 ) for _ in 1 : 10 ]
1475+ x_ra = Reactant. to_rarray (x)
1476+
1477+ @test @jit (enumerate_iterator2 (x_ra)) ≈ enumerate_iterator2 (x)
14631478 end
14641479
14651480 @testset " nested mapreduce" begin
@@ -1481,6 +1496,15 @@ end
14811496
14821497 @test @jit (nested_mapreduce_hcat (x_ra, y_ra)) ≈ nested_mapreduce_hcat (x, y)
14831498 end
1499+
1500+ @testset " mapreduce vector" begin
1501+ x = [rand (Float32, 2 , 3 ) for _ in 1 : 10 ]
1502+ x_ra = Reactant. to_rarray (x)
1503+
1504+ @test @jit (mapreduce_vector (x_ra)) ≈ mapreduce_vector (x)
1505+ hlo = repr (@code_hlo optimize = false mapreduce_vector (x_ra))
1506+ @test contains (hlo, " call" )
1507+ end
14841508end
14851509
14861510@testset " compilation cache" begin
0 commit comments