diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index a87f9dda..00149bee 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -480,19 +480,16 @@ defmodule Axon.Compiler do
         &to_model_funs(&1, nodes, &2, config)
       )
 
-    input_nodes = get_input_subgraph_nodes(parent, nodes, %{})
-    input_subgraph = %Axon{output: parent, nodes: input_nodes}
-
-    {{block_init_fun, block_predict_fun}, init_block?, block_name, block_cache, op_counts} =
+    {{block_init_fun, block_predict_fun}, block_name, block_cache, op_counts} =
       case block_cache do
         %{^block_id => {funs, name}} = block_cache ->
-          {funs, false, name, block_cache, op_counts}
+          {funs, name, block_cache, op_counts}
 
         %{} ->
-          funs = build(block_fun.(input_subgraph), debug?: config.debug?)
+          funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?)
           name = name_fn.(:block, op_counts)
           op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end)
-          {funs, true, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
+          {funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
       end
 
     predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
@@ -547,7 +544,7 @@ defmodule Axon.Compiler do
       end
 
     init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
-      {[_parent_shape], {parent_params, result_cache, none?}} =
+      {[parent_shape], {parent_params, result_cache, none?}} =
         Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn parent_id,
                                                                     {params, result_cache, none?} ->
           {parent_shape, {params, result_cache}} =
@@ -568,13 +565,14 @@ defmodule Axon.Compiler do
       if none? do
         {%Axon.None{}, {parent_params, result_cache}}
       else
+        template = Nx.broadcast(0.0, parent_shape)
         block_params = apply(block_init_fun, [template, %{}])
 
         params =
           if block_params == %{} do
             %{}
           else
-            %{block_name => block_params}
+            Map.put(parent_params, block_name, block_params)
           end
 
         {pred_expr, {_, result_cache}} =
@@ -1094,19 +1092,4 @@ defmodule Axon.Compiler do
   defp propagating_none?(_), do: false
 
   defp us_to_ms(time), do: Float.round(time / 1000, 1)
-
-  defp get_input_subgraph_nodes(parent_id, nodes, acc) do
-    case nodes do
-      %{^parent_id => %{parent: parents} = node} ->
-        acc =
-          Enum.reduce(parents, acc, fn id, acc ->
-            get_input_subgraph_nodes(id, nodes, acc)
-          end)
-
-        Map.put(acc, parent_id, node)
-
-      %{} ->
-        acc
-    end
-  end
 end
diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index 98066421..7eee9c23 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -5253,7 +5253,7 @@ defmodule CompilerTest do
 
       assert Nx.shape(k1) == {1, 32}
       assert Nx.shape(b1) == {32}
-      assert Nx.shape(k2) == {1, 32}
+      assert Nx.shape(k2) == {32, 32}
       assert Nx.shape(b2) == {32}
       assert Nx.type(k1) == {:f, 32}
       assert Nx.type(b1) == {:f, 32}
@@ -5301,6 +5301,122 @@ defmodule CompilerTest do
       assert map_size(params) == 1
     end
 
+    test "initializes correctly with multiple blocks in network" do
+      block1 = Axon.block(&Axon.dense(&1, 32))
+      block2 = Axon.block(&Axon.dense(&1, 32))
+
+      model =
+        Axon.input("features")
+        |> block1.()
+        |> block2.()
+
+      {init_fn, _} = Axon.build(model)
+
+      assert %{
+               "block_0" =>
+                 %{
+                   "dense_0" => %{"kernel" => k1, "bias" => b1}
+                 } = block_0_params,
+               "block_1" =>
+                 %{
+                   "dense_0" => %{"kernel" => k2, "bias" => b2}
+                 } = block_1_params
+             } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
+
+      assert Nx.shape(k1) == {1, 32}
+      assert Nx.shape(b1) == {32}
+      assert Nx.shape(k2) == {32, 32}
+      assert Nx.shape(b2) == {32}
+      assert Nx.type(k1) == {:f, 32}
+      assert Nx.type(b1) == {:f, 32}
+      assert Nx.type(k2) == {:f, 32}
+      assert Nx.type(b2) == {:f, 32}
+
+      # no additional dense layers in block
+      assert map_size(block_0_params) == 1
+      assert map_size(block_1_params) == 1
+      # no additional blocks
+      assert map_size(params) == 2
+    end
+
+    test "initializes correctly with block inside of a block" do
+      block =
+        Axon.block(fn x ->
+          inner_block = Axon.block(&Axon.dense(&1, 1))
+
+          x |> inner_block.() |> inner_block.()
+        end)
+
+      model =
+        Axon.input("features")
+        |> block.()
+        |> block.()
+
+      {init_fn, _} = Axon.build(model)
+
+      assert %{
+               "block_0" =>
+                 %{
+                   "block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} = inner_block_params
+                 } = block_params
+             } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
+
+      assert Nx.shape(k) == {1, 1}
+      assert Nx.shape(b) == {1}
+      assert Nx.type(k) == {:f, 32}
+      assert Nx.type(b) == {:f, 32}
+
+      assert map_size(inner_block_params) == 1
+      assert map_size(block_params) == 1
+      assert map_size(params) == 1
+    end
+
+    test "initializes correctly when using block from outside scope" do
+      block1 = Axon.block(&Axon.dense(&1, 32))
+
+      block2 =
+        Axon.block(fn x ->
+          x |> Axon.dense(32) |> block1.()
+        end)
+
+      model =
+        Axon.input("features")
+        |> Axon.dense(32)
+        |> block1.()
+        |> block2.()
+        |> block1.()
+        |> block2.()
+
+      {init_fn, _} = Axon.build(model)
+
+      assert %{
+               "block_0" => block_0_params,
+               "block_1" => block_1_params,
+               "dense_0" => dense_0_params
+             } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
+
+      assert %{"dense_0" => %{"kernel" => b0k, "bias" => b0b}} = block_0_params
+      assert %{"dense_0" => %{"kernel" => b1k, "bias" => b1b}} = block_1_params
+      assert %{"kernel" => k1, "bias" => b1} = dense_0_params
+
+      assert Nx.shape(b0k) == {32, 32}
+      assert Nx.shape(b1k) == {32, 32}
+      assert Nx.shape(b0b) == {32}
+      assert Nx.shape(b1b) == {32}
+      assert Nx.shape(k1) == {1, 32}
+      assert Nx.shape(b1) == {32}
+      assert Nx.type(b0k) == {:f, 32}
+      assert Nx.type(b0b) == {:f, 32}
+      assert Nx.type(b1k) == {:f, 32}
+      assert Nx.type(b1b) == {:f, 32}
+      assert Nx.type(k1) == {:f, 32}
+      assert Nx.type(b1) == {:f, 32}
+
+      assert map_size(params) == 3
+      assert map_size(block_0_params) == 1
+      assert map_size(block_1_params) == 1
+    end
+
     test "predicts correctly with single dense, used once" do
       block = Axon.block(&Axon.dense(&1, 32))
       model = block.(Axon.input("features"))
@@ -5317,6 +5433,7 @@ defmodule CompilerTest do
 
     test "predicts correctly with single dense, used twice" do
       block = Axon.block(&Axon.dense(&1, 1))
+
       model =
         Axon.input("features")
         |> block.()
@@ -5329,7 +5446,8 @@ defmodule CompilerTest do
 
       input = random({1, 1})
 
-      assert predict_fn.(params, input) == input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b)
+      assert predict_fn.(params, input) ==
+               input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b)
     end
 
     test "predicts correctly with multiple dense, used once" do
@@ -5344,11 +5462,10 @@ defmodule CompilerTest do
       {init_fn, predict_fn} = Axon.build(model)
 
       assert %{
-               "block_0" =>
-                 %{
-                   "dense_0" => %{"kernel" => k1, "bias" => b1},
-                   "dense_1" => %{"kernel" => k2, "bias" => b2}
-                 }
+               "block_0" => %{
+                 "dense_0" => %{"kernel" => k1, "bias" => b1},
+                 "dense_1" => %{"kernel" => k2, "bias" => b2}
+               }
              } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
 
       expected_predict_fn = fn x, k1, b1, k2, b2 ->
@@ -5380,11 +5497,10 @@ defmodule CompilerTest do
       {init_fn, predict_fn} = Axon.build(model)
 
       assert %{
-               "block_0" =>
-                 %{
-                   "dense_0" => %{"kernel" => k1, "bias" => b1},
-                   "dense_1" => %{"kernel" => k2, "bias" => b2}
-                 }
+               "block_0" => %{
+                 "dense_0" => %{"kernel" => k1, "bias" => b1},
+                 "dense_1" => %{"kernel" => k2, "bias" => b2}
+               }
             } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
 
       expected_predict_fn = fn x, k1, b1, k2, b2 ->
@@ -5403,6 +5519,70 @@ defmodule CompilerTest do
 
       assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2)
     end
+
+    test "predicts correctly with multiple blocks in network" do
+      block1 = Axon.block(&Axon.dense(&1, 32))
+      block2 = Axon.block(&Axon.dense(&1, 32))
+
+      model =
+        Axon.input("features")
+        |> block1.()
+        |> block2.()
+
+      {init_fn, predict_fn} = Axon.build(model)
+
+      actual_predict_fn = fn x, k1, b1, k2, b2 ->
+        x
+        |> Axon.Layers.dense(k1, b1)
+        |> Axon.Layers.dense(k2, b2)
+      end
+
+      assert %{
+               "block_0" => %{
+                 "dense_0" => %{"kernel" => k1, "bias" => b1}
+               },
+               "block_1" => %{
+                 "dense_0" => %{"kernel" => k2, "bias" => b2}
+               }
+             } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
+
+      input = random({1, 1})
+
+      assert predict_fn.(params, input) == actual_predict_fn.(input, k1, b1, k2, b2)
+    end
+
+    test "predicts correctly with block inside of a block" do
+      block =
+        Axon.block(fn x ->
+          inner_block = Axon.block(&Axon.dense(&1, 1))
+
+          x |> inner_block.() |> inner_block.()
+        end)
+
+      model =
+        Axon.input("features")
+        |> block.()
+        |> block.()
+
+      {init_fn, predict_fn} = Axon.build(model)
+
+      actual_predict_fn = fn x, k, b ->
+        x
+        |> Axon.Layers.dense(k, b)
+        |> Axon.Layers.dense(k, b)
+        |> Axon.Layers.dense(k, b)
+        |> Axon.Layers.dense(k, b)
+      end
+
+      assert %{
+               "block_0" => %{
+                 "block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}
+               }
+             } = params = init_fn.(Nx.template({1, 1}, :f32), %{})
+
+      input = random({1, 1})
+      assert predict_fn.(params, input) == actual_predict_fn.(input, k, b)
+    end
   end
 
   describe "initializers" do