diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index e29af742..dea849b5 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -1660,15 +1660,10 @@ defmodule Axon.Loop do # TODO: Can we infer here? zero_metrics = Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0, type: :f32)} end) - - final_metrics_map = - epoch_start..epoch_end - |> Map.new(&{&1, zero_metrics}) - |> Map.merge(loop_state.metrics) - + final_metrics_map = loop_state.metrics loop_state = %{loop_state | metrics: zero_metrics} - {status, final_metrics, state} = + {status, final_metrics_map, state} = case fire_event(:started, handler_fns, loop_state, debug?) do {:halt_epoch, state} -> {:halted, final_metrics_map, state} @@ -1722,7 +1717,7 @@ defmodule Axon.Loop do {:continue, state} -> {:cont, - {batch_fn, %{final_metrics_map | epoch => state.metrics}, + {batch_fn, Map.put(final_metrics_map, epoch, state.metrics), %State{ state | epoch: epoch + 1, @@ -1741,8 +1736,15 @@ defmodule Axon.Loop do end end - state = %State{state | metrics: final_metrics, status: status} + # Fill in epochs in case it was halted. It is a no-op otherwise. + final_metrics_map = + Enum.reduce( + state.epoch..epoch_end//1, + final_metrics_map, + &Map.put(&2, &1, zero_metrics) + ) + state = %State{state | metrics: final_metrics_map, status: status} output_transform.(state) end @@ -1919,9 +1921,9 @@ defmodule Axon.Loop do # Halts an epoch during looping defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do - {:halt_epoch, state} -> - {:cont, - {batch_fn, final_metrics_map, %State{state | epoch: state.epoch + 1, iteration: 0}}} + {:halt_epoch, %{epoch: epoch, metrics: metrics} = state} -> + final_metrics_map = Map.put(final_metrics_map, epoch, metrics) + {:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}} {:halt_loop, state} -> {:halt, {final_metrics_map, state}}