From 89b2e7c8e3588b8b0f8922d174f4b004f9a50e91 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Wed, 18 Oct 2023 09:36:17 -0400 Subject: [PATCH] Name hooks --- lib/axon/compiler.ex | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 4da5c7cf..f4867fcc 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -349,8 +349,8 @@ defmodule Axon.Compiler do res = value - |> apply_hooks(:forward, mode, hooks) - |> apply_hooks(:backward, mode, hooks) + |> apply_hooks(name, :forward, mode, hooks) + |> apply_hooks(name, :backward, mode, hooks) {res, {state, result_cache}} end @@ -824,7 +824,7 @@ defmodule Axon.Compiler do layer_input = layer_input |> safe_as_type(compute) - |> apply_hooks(:pre_forward, mode, hooks) + |> apply_hooks(name, :pre_forward, mode, hooks) {layer_input, {state, result_cache, none?}} end @@ -893,8 +893,8 @@ defmodule Axon.Compiler do %StatefulOutput{output: out, state: out_state} -> new_out = out - |> apply_hooks(:forward, mode, hooks) - |> apply_hooks(:backward, mode, hooks) + |> apply_hooks(name, :forward, mode, hooks) + |> apply_hooks(name, :backward, mode, hooks) |> safe_as_type(output) new_state = Map.put(state, name, out_state) @@ -903,8 +903,8 @@ defmodule Axon.Compiler do out -> new_out = out - |> apply_hooks(:forward, mode, hooks) - |> apply_hooks(:backward, mode, hooks) + |> apply_hooks(name, :forward, mode, hooks) + |> apply_hooks(name, :backward, mode, hooks) |> safe_as_type(output) {new_out, state} @@ -1003,7 +1003,7 @@ defmodule Axon.Compiler do init_param(layer_id, param, layer_params, parent_shapes, dtype, keys) end) - layer_params = apply_hooks(layer_params, :initialize, nil, hooks) + layer_params = apply_hooks(layer_params, name, :initialize, nil, hooks) params = if layer_params == %{} do @@ -1054,7 +1054,7 @@ defmodule Axon.Compiler do defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param) defp maybe_freeze(param, false), do: param - defp apply_hooks(res, event, mode, hooks) do + defp apply_hooks(res, layer_name, event, mode, hooks) do hooks |> Enum.reverse() |> Enum.reduce(res, fn {on_event, on_mode, hook_fn}, expr -> @@ -1068,7 +1068,7 @@ defmodule Axon.Compiler do [hooked_g] end) else - Nx.Defn.Kernel.hook(expr, hook_fn) + Nx.Defn.Kernel.hook(expr, String.to_atom(layer_name), hook_fn) end else expr