Skip to content

Commit 6ad2d04

Browse files
test: fold in Lux tests
1 parent 674eb2a commit 6ad2d04

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

test/extensions/lux.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,31 @@ using ComponentArrays
2222
@test out isa Symbolics.Arr
2323
@test length(out) == 6
2424
# test that we can recover the same value as when using concrete numbers
25-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
25+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),); fold = Val(true)))
2626
@test out_sub == out_ref
2727

2828
out = LuxCore.stateless_apply(model, sym_x, sym_ps)
2929
@test out isa Symbolics.Arr
3030
@test length(out) == 6
31-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
31+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),); fold = Val(true)))
3232
@test out_sub == out_ref
3333

3434
out = LuxCore.stateless_apply(model, sym_x, ca)
3535
@test out isa Symbolics.Arr
3636
@test length(out) == 6
37-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
37+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),); fold = Val(true)))
3838
@test out_sub == out_ref
3939

4040
out = LuxCore.stateless_apply(model, sym_x, sym_ca)
4141
@test out isa Symbolics.Arr
4242
@test length(out) == 6
43-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
43+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),); fold = Val(true)))
4444
@test out_sub == out_ref
4545

4646
out = LuxCore.stateless_apply(sym_model, sym_x, sym_ca)
4747
@test out isa Symbolics.Arr
4848
@test length(out) == 6
49-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_model => model, sym_x => x, sym_ca => ca),)))
49+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_model => model, sym_x => x, sym_ca => ca),); fold = Val(true)))
5050
@test out_sub == out_ref
5151
end
5252

@@ -71,30 +71,30 @@ end
7171
@test out isa Symbolics.Arr
7272
@test length(out) == 3
7373
# test that we can recover the same value as when using concrete numbers
74-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
74+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),); fold = Val(true)))
7575
@test out_sub == out_ref
7676

7777
out = LuxCore.stateless_apply(model, sym_x, sym_ps)
7878
@test out isa Symbolics.Arr
7979
@test length(out) == 3
80-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
80+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),); fold = Val(true)))
8181
@test out_sub == out_ref
8282

8383
out = LuxCore.stateless_apply(model, sym_x, ca)
8484
@test out isa Symbolics.Arr
8585
@test length(out) == 3
86-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
86+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),); fold = Val(true)))
8787
@test out_sub == out_ref
8888

8989
out = LuxCore.stateless_apply(model, sym_x, sym_ca)
9090
@test out isa Symbolics.Arr
9191
@test length(out) == 3
92-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
92+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),); fold = Val(true)))
9393
@test out_sub == out_ref
9494

9595
out = LuxCore.stateless_apply(sym_model, sym_x, sym_ca)
9696
@test out isa Symbolics.Arr
9797
@test length(out) == 3
98-
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_model => model, sym_x => x, sym_ca => ca),)))
98+
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_model => model, sym_x => x, sym_ca => ca),); fold = Val(true)))
9999
@test out_sub == out_ref
100100
end

0 commit comments

Comments
 (0)