Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Oct 9, 2023
1 parent 4d11383 commit a78a3a2
Showing 1 changed file with 0 additions and 46 deletions.
46 changes: 0 additions & 46 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5371,52 +5371,6 @@ defmodule CompilerTest do
assert map_size(params) == 1
end

test "initializes correctly when using block from outside scope" do
block1 = Axon.block(&Axon.dense(&1, 32))

block2 =
Axon.block(fn x ->
x |> Axon.dense(32) |> block1.()
end)

model =
Axon.input("features")
|> Axon.dense(32)
|> block1.()
|> block2.()
|> block1.()
|> block2.()

{init_fn, _} = Axon.build(model)

assert %{
"block_0" => block_0_params,
"block_1" => block_1_params,
"dense_0" => dense_0_params
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

assert %{"dense_0" => %{"kernel" => b0k, "bias" => b0b}} = block_0_params
assert %{"dense_0" => %{"kernel" => b1k, "bias" => b1b}} = block_1_params
assert %{"kernel" => k1, "bias" => b1} = dense_0_params

assert Nx.shape(b0k) == {32, 32}
assert Nx.shape(b1k) == {32, 32}
assert Nx.shape(b0b) == {32}
assert Nx.shape(b1b) == {32}
assert Nx.shape(k1) == {1, 32}
assert Nx.shape(b1) == {32}
assert Nx.type(b0k) == {:f, 32}
assert Nx.type(b0b) == {:f, 32}
assert Nx.type(b1k) == {:f, 32}
assert Nx.type(b1b) == {:f, 32}
assert Nx.type(k1) == {:f, 32}
assert Nx.type(b1) == {:f, 32}

assert map_size(params) == 3
assert map_size(block_0_params) == 1
assert map_size(block_1_params) == 1
end

test "predicts correctly with single dense, used once" do
block = Axon.block(&Axon.dense(&1, 32))
model = block.(Axon.input("features"))
Expand Down

0 comments on commit a78a3a2

Please sign in to comment.