Skip to content

Commit

Permalink
Update tests wrt. virtual nodes optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed May 6, 2023
1 parent c780082 commit f35e50f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 36 deletions.
4 changes: 0 additions & 4 deletions lib/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,6 @@ let virtual_llc node_store reverse_node_map (llc : unit low_level) : unit low_le
in
loop_proc ~process_for:Set.Poly.empty llc

module Debug_runtime = Minidebug_runtime.Flushing (struct
let debug_ch = Stdio.stdout
end)

let cleanup_virtual_llc node_store reverse_node_map (llc : unit low_level) : unit low_level =
let is_inline tensor =
let node = Hashtbl.find_exn node_store tensor in
Expand Down
1 change: 0 additions & 1 deletion lib/nodeUI.ml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ let retrieve_2d_points ?from_axis ~xdim ~ydim arr =
iter 0;
Array.of_list_rev !result

(* module Debug_runtime = Minidebug_runtime.Flushing(struct let debug_ch = Stdio.stdout end) *)
let retrieve_1d_points ?from_axis ~xdim arr =
let dims = N.dims arr in
if Array.is_empty dims then [||]
Expand Down
4 changes: 3 additions & 1 deletion test/einsum_trivia.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,9 @@ let%expect_test "einsum1 fixed dim axis" =
└────────────────────────┘ |}];
let%nn_op ho4 = hey2 ++ "i->j => i0j" in
print_formula ~with_code:false ~with_grad:false `Default @@ ho4;
[%expect {| <void> |}]
[%expect {|
[6]: ho4 <=>> shape 0:2,1:1,2:3
<void> |}]

let%expect_test "einsum with fixed dim axes" =
let open Session.SDSL in
Expand Down
61 changes: 31 additions & 30 deletions test/zero2hero_1of7.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,26 @@ let%expect_test "Graph drawing recompile" =
SDSL.print_node_tree ~with_grad:true ~depth:9 f.id;
[%expect
{|
[13] f <+>
6.00e+1
Gradient
1.00e+0
[12] <+> │[2] <5>
5.50e+1 5.00e+0
Gradient
1.00e+0
[9] <*.> │ [11] <*.>
7.50e+1-2.00e+1
GradientGradient
1.00e+01.00e+0
[8] <3> │ [6] <**.> │[10] <-1> │ [4] <*.>
3.00e+0 2.50e+1 -1.00e+0 2.00e+1
Gradient │ │ Gradient
3.00e+0 │ │ -1.00e+0
│[1]│[5] <2> │ │[3] <4> │[1] <x>
│ │ 2.00e+0 4.00e+05.00e+0
│ │ │ Gradient
│ │ │ 2.60e+1|}];
[13] f <+>
6.00e+1
Gradient
1.00e+0
[12] <+> │[2] <5> virtual
5.50e+1 <void>
Gradient
1.00e+0
[9] <*.> [11] <*.>
7.50e+1 -2.00e+1
Gradient Gradient
1.00e+0 1.00e+0
[8] <3> virtual [6] <**.> │[10] <-1> virtual [4] <*.>
<void> 2.50e+1 <void> 2.00e+1
Gradient Gradient
3.00e+0 -1.00e+0
│[1]│[5] <2> virtual │[3] <4> virtual│[1] <x>
│ │<void> <void> 5.00e+0
│ │ Gradient
│ │ 2.60e+1|}];
let xs = Array.init 10 ~f:Float.(fun i -> of_int i - 5.) in
let ys =
Array.map xs ~f:(fun v ->
Expand Down Expand Up @@ -96,22 +96,23 @@ let%expect_test "Graph drawing fetch" =
let open SDSL.O in
SDSL.drop_all_sessions ();
Random.init 0;
CDSL.debug_virtual_nodes := true;
let%nn_op f x = (3 *. (x **. 2)) - (4 *. x) + 5 in
let%nn_op f5 = f 5 in
SDSL.refresh_session ();
SDSL.print_node_tree ~with_grad:false ~depth:9 f5.id;
[%expect
{|
[12] f <+>
6.00e+1
[11] <+> │[2] <5>
5.50e+15.00e+0
[8] <*.> │ [10] <*.>
7.50e+1-2.00e+1
[7] <3> [6] <**.> │[9] <-1> │ [4] <*.>
3.00e+0 2.50e+1-1.00e+02.00e+1
│[1]│[5] <2> │ │[3] <4> │[1] <5>
│ │ 2.00e+0 │ │ 4.00e+05.00e+0|}];
[12] f <+>
6.00e+1
[11] <+> virtual │[2] <5> virtual
5.50e+1 5.00e+0
[8] <*.> virtual [10] <*.> virtual
7.50e+1 -2.00e+1
[7] <3> virtual[6] <**.> virtual │[9] <-1> virtual [4] <*.> virtual
3.00e+0 2.50e+1 -1.00e+0 2.00e+1
│[1]│[5] <2> virtual │[3] <4> virtual│[1] <5> virtual
│ │ 2.00e+0 4.00e+0 5.00e+0 |}];
(* close_session is not necessary. *)
SDSL.close_session ();
let size = 100 in
Expand Down

0 comments on commit f35e50f

Please sign in to comment.