Skip to content

Commit

Permalink
Almost finished block
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Oct 9, 2023
1 parent b2a7c28 commit 4d11383
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 36 deletions.
31 changes: 7 additions & 24 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -480,19 +480,16 @@ defmodule Axon.Compiler do
&to_model_funs(&1, nodes, &2, config)
)

input_nodes = get_input_subgraph_nodes(parent, nodes, %{})
input_subgraph = %Axon{output: parent, nodes: input_nodes}

{{block_init_fun, block_predict_fun}, init_block?, block_name, block_cache, op_counts} =
{{block_init_fun, block_predict_fun}, block_name, block_cache, op_counts} =
case block_cache do
%{^block_id => {funs, name}} = block_cache ->
{funs, false, name, block_cache, op_counts}
{funs, name, block_cache, op_counts}

%{} ->
funs = build(block_fun.(input_subgraph), debug?: config.debug?)
funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?)
name = name_fn.(:block, op_counts)
op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end)
{funs, true, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
{funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
end

predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
Expand Down Expand Up @@ -547,7 +544,7 @@ defmodule Axon.Compiler do
end

init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
{[_parent_shape], {parent_params, result_cache, none?}} =
{[parent_shape], {parent_params, result_cache, none?}} =
Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn
parent_id, {params, result_cache, none?} ->
{parent_shape, {params, result_cache}} =
Expand All @@ -568,13 +565,14 @@ defmodule Axon.Compiler do
if none? do
{%Axon.None{}, {parent_params, result_cache}}
else
template = Nx.broadcast(0.0, parent_shape)
block_params = apply(block_init_fun, [template, %{}])

params =
if block_params == %{} do
%{}
else
%{block_name => block_params}
Map.put(parent_params, block_name, block_params)
end

{pred_expr, {_, result_cache}} =
Expand Down Expand Up @@ -1094,19 +1092,4 @@ defmodule Axon.Compiler do
defp propagating_none?(_), do: false

defp us_to_ms(time), do: Float.round(time / 1000, 1)

defp get_input_subgraph_nodes(parent_id, nodes, acc) do
case nodes do
%{^parent_id => %{parent: parents} = node} ->
acc =
Enum.reduce(parents, acc, fn id, acc ->
get_input_subgraph_nodes(id, nodes, acc)
end)

Map.put(acc, parent_id, node)

%{} ->
acc
end
end
end
204 changes: 192 additions & 12 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5253,7 +5253,7 @@ defmodule CompilerTest do

assert Nx.shape(k1) == {1, 32}
assert Nx.shape(b1) == {32}
assert Nx.shape(k2) == {1, 32}
assert Nx.shape(k2) == {32, 32}
assert Nx.shape(b2) == {32}
assert Nx.type(k1) == {:f, 32}
assert Nx.type(b1) == {:f, 32}
Expand Down Expand Up @@ -5301,6 +5301,122 @@ defmodule CompilerTest do
assert map_size(params) == 1
end

test "initializes correctly with multiple blocks in network" do
block1 = Axon.block(&Axon.dense(&1, 32))
block2 = Axon.block(&Axon.dense(&1, 32))

model =
Axon.input("features")
|> block1.()
|> block2.()

{init_fn, _} = Axon.build(model)

assert %{
"block_0" =>
%{
"dense_0" => %{"kernel" => k1, "bias" => b1}
} = block_0_params,
"block_1" =>
%{
"dense_0" => %{"kernel" => k2, "bias" => b2}
} = block_1_params
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

assert Nx.shape(k1) == {1, 32}
assert Nx.shape(b1) == {32}
assert Nx.shape(k2) == {32, 32}
assert Nx.shape(b2) == {32}
assert Nx.type(k1) == {:f, 32}
assert Nx.type(b1) == {:f, 32}
assert Nx.type(k2) == {:f, 32}
assert Nx.type(b2) == {:f, 32}

# no additional dense layers in block
assert map_size(block_0_params) == 1
assert map_size(block_1_params) == 1
# no additional blocks
assert map_size(params) == 2
end

test "initializes correctly with block inside of a block" do
block =
Axon.block(fn x ->
inner_block = Axon.block(&Axon.dense(&1, 1))

x |> inner_block.() |> inner_block.()
end)

model =
Axon.input("features")
|> block.()
|> block.()

{init_fn, _} = Axon.build(model)

assert %{
"block_0" =>
%{
"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} = inner_block_params
} = block_params
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

assert Nx.shape(k) == {1, 1}
assert Nx.shape(b) == {1}
assert Nx.type(k) == {:f, 32}
assert Nx.type(b) == {:f, 32}

assert map_size(inner_block_params) == 1
assert map_size(block_params) == 1
assert map_size(params) == 1
end

test "initializes correctly when using block from outside scope" do
block1 = Axon.block(&Axon.dense(&1, 32))

block2 =
Axon.block(fn x ->
x |> Axon.dense(32) |> block1.()
end)

model =
Axon.input("features")
|> Axon.dense(32)
|> block1.()
|> block2.()
|> block1.()
|> block2.()

{init_fn, _} = Axon.build(model)

assert %{
"block_0" => block_0_params,
"block_1" => block_1_params,
"dense_0" => dense_0_params
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

assert %{"dense_0" => %{"kernel" => b0k, "bias" => b0b}} = block_0_params
assert %{"dense_0" => %{"kernel" => b1k, "bias" => b1b}} = block_1_params
assert %{"kernel" => k1, "bias" => b1} = dense_0_params

assert Nx.shape(b0k) == {32, 32}
assert Nx.shape(b1k) == {32, 32}
assert Nx.shape(b0b) == {32}
assert Nx.shape(b1b) == {32}
assert Nx.shape(k1) == {1, 32}
assert Nx.shape(b1) == {32}
assert Nx.type(b0k) == {:f, 32}
assert Nx.type(b0b) == {:f, 32}
assert Nx.type(b1k) == {:f, 32}
assert Nx.type(b1b) == {:f, 32}
assert Nx.type(k1) == {:f, 32}
assert Nx.type(b1) == {:f, 32}

assert map_size(params) == 3
assert map_size(block_0_params) == 1
assert map_size(block_1_params) == 1
end

test "predicts correctly with single dense, used once" do
block = Axon.block(&Axon.dense(&1, 32))
model = block.(Axon.input("features"))
Expand All @@ -5317,6 +5433,7 @@ defmodule CompilerTest do

test "predicts correctly with single dense, used twice" do
block = Axon.block(&Axon.dense(&1, 1))

model =
Axon.input("features")
|> block.()
Expand All @@ -5329,7 +5446,8 @@ defmodule CompilerTest do

input = random({1, 1})

assert predict_fn.(params, input) == input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b)
assert predict_fn.(params, input) ==
input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b)
end

test "predicts correctly with multiple dense, used once" do
Expand All @@ -5344,11 +5462,10 @@ defmodule CompilerTest do
{init_fn, predict_fn} = Axon.build(model)

assert %{
"block_0" =>
%{
"dense_0" => %{"kernel" => k1, "bias" => b1},
"dense_1" => %{"kernel" => k2, "bias" => b2}
}
"block_0" => %{
"dense_0" => %{"kernel" => k1, "bias" => b1},
"dense_1" => %{"kernel" => k2, "bias" => b2}
}
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

expected_predict_fn = fn x, k1, b1, k2, b2 ->
Expand Down Expand Up @@ -5380,11 +5497,10 @@ defmodule CompilerTest do
{init_fn, predict_fn} = Axon.build(model)

assert %{
"block_0" =>
%{
"dense_0" => %{"kernel" => k1, "bias" => b1},
"dense_1" => %{"kernel" => k2, "bias" => b2}
}
"block_0" => %{
"dense_0" => %{"kernel" => k1, "bias" => b1},
"dense_1" => %{"kernel" => k2, "bias" => b2}
}
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

expected_predict_fn = fn x, k1, b1, k2, b2 ->
Expand All @@ -5403,6 +5519,70 @@ defmodule CompilerTest do

assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2)
end

test "predicts correctly with multiple blocks in network" do
block1 = Axon.block(&Axon.dense(&1, 32))
block2 = Axon.block(&Axon.dense(&1, 32))

model =
Axon.input("features")
|> block1.()
|> block2.()

{init_fn, predict_fn} = Axon.build(model)

actual_predict_fn = fn x, k1, b1, k2, b2 ->
x
|> Axon.Layers.dense(k1, b1)
|> Axon.Layers.dense(k2, b2)
end

assert %{
"block_0" => %{
"dense_0" => %{"kernel" => k1, "bias" => b1}
},
"block_1" => %{
"dense_0" => %{"kernel" => k2, "bias" => b2}
}
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

input = random({1, 1})

assert predict_fn.(params, input) == actual_predict_fn.(input, k1, b1, k2, b2)
end

test "predicts correctly with block inside of a block" do
block =
Axon.block(fn x ->
inner_block = Axon.block(&Axon.dense(&1, 1))

x |> inner_block.() |> inner_block.()
end)

model =
Axon.input("features")
|> block.()
|> block.()

{init_fn, predict_fn} = Axon.build(model)

actual_predict_fn = fn x, k, b ->
x
|> Axon.Layers.dense(k, b)
|> Axon.Layers.dense(k, b)
|> Axon.Layers.dense(k, b)
|> Axon.Layers.dense(k, b)
end

assert %{
"block_0" => %{
"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}
}
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

input = random({1, 1})
assert predict_fn.(params, input) == actual_predict_fn.(input, k, b)
end
end

describe "initializers" do
Expand Down

0 comments on commit 4d11383

Please sign in to comment.