Skip to content

Commit

Permalink
Name hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Oct 18, 2023
1 parent 235a030 commit 89b2e7c
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ defmodule Axon.Compiler do

res =
value
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)

{res, {state, result_cache}}
end
Expand Down Expand Up @@ -824,7 +824,7 @@ defmodule Axon.Compiler do
layer_input =
layer_input
|> safe_as_type(compute)
|> apply_hooks(:pre_forward, mode, hooks)
|> apply_hooks(name, :pre_forward, mode, hooks)

{layer_input, {state, result_cache, none?}}
end
Expand Down Expand Up @@ -893,8 +893,8 @@ defmodule Axon.Compiler do
%StatefulOutput{output: out, state: out_state} ->
new_out =
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> safe_as_type(output)

new_state = Map.put(state, name, out_state)
Expand All @@ -903,8 +903,8 @@ defmodule Axon.Compiler do
out ->
new_out =
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> safe_as_type(output)

{new_out, state}
Expand Down Expand Up @@ -1003,7 +1003,7 @@ defmodule Axon.Compiler do
init_param(layer_id, param, layer_params, parent_shapes, dtype, keys)
end)

layer_params = apply_hooks(layer_params, :initialize, nil, hooks)
layer_params = apply_hooks(layer_params, name, :initialize, nil, hooks)

params =
if layer_params == %{} do
Expand Down Expand Up @@ -1054,7 +1054,7 @@ defmodule Axon.Compiler do
defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param)
defp maybe_freeze(param, false), do: param

defp apply_hooks(res, event, mode, hooks) do
defp apply_hooks(res, layer_name, event, mode, hooks) do
hooks
|> Enum.reverse()
|> Enum.reduce(res, fn {on_event, on_mode, hook_fn}, expr ->
Expand All @@ -1068,7 +1068,7 @@ defmodule Axon.Compiler do
[hooked_g]
end)
else
Nx.Defn.Kernel.hook(expr, hook_fn)
Nx.Defn.Kernel.hook(expr, String.to_atom(layer_name), hook_fn)
end
else
expr
Expand Down

0 comments on commit 89b2e7c

Please sign in to comment.