diff --git a/lib/axon.ex b/lib/axon.ex index 7d07b0de..94c6e91c 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -240,9 +240,9 @@ defmodule Axon do to inference function except: * `:name` - layer name. + * `:op_name` - layer operation for inspection and building parameter map. + * `:mode` - if the layer should run only on `:inference` or `:train`. Defaults to `:both` - * `:op_name` - layer operation for inspection and building parameter - map. Note this means your layer should not use these as input options, as they will always be dropped during inference compilation. @@ -268,22 +268,23 @@ defmodule Axon do params = Enum.reverse(params) args = Enum.reverse(args) + {mode, opts} = Keyword.pop(opts, :mode, :both) {name, opts} = Keyword.pop(opts, :name) - {op_name, layer_opts} = Keyword.pop(opts, :op_name, :custom) - - {id, name} = unique_identifiers(op_name, name) - - axon_node = make_node(id, op, name, op_name, inputs, params, args, layer_opts) + {op_name, opts} = Keyword.pop(opts, :op_name, :custom) + name = name(op_name, name) + id = System.unique_integer([:positive, :monotonic]) + axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts) %Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)} end - defp make_node(id, op, name, op_name, inputs, params, args, layer_opts) do + defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts) do {:current_stacktrace, [_process_info, _axon_layer | stacktrace]} = Process.info(self(), :current_stacktrace) %Axon.Node{ id: id, + mode: mode, name: name, parent: inputs, parameters: params, @@ -340,10 +341,7 @@ defmodule Axon do initializer = validate_initializer!(opts[:initializer]) type = opts[:type] || {:f, 32} - id = System.unique_integer([:positive, :monotonic]) - %Axon.Parameter{ - id: id, name: name, shape: shape, type: type, @@ -1399,7 +1397,8 @@ defmodule Axon do layer(dropout, [x, key_state], name: opts[:name], rate: opts[:rate], - op_name: dropout + op_name: dropout, + mode: :train ) end @@ -2174,8 +2173,9 @@ defmodule Axon do def lstm(%Axon{} = x, units, opts) when is_integer(units) and units > 0 and is_list(opts) do {recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform) - c = rnn_state(x, units, :lstm, opts[:name], "c", recurrent_initializer) - h = rnn_state(x, units, :lstm, opts[:name], "h", recurrent_initializer) + {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end) + c = rnn_state(x, units, :lstm, opts[:name], "c", recurrent_initializer, seed) + h = rnn_state(x, units, :lstm, opts[:name], "h", recurrent_initializer, seed) lstm(x, {c, h}, units, opts) end @@ -2372,7 +2372,8 @@ defmodule Axon do when is_integer(units) and units > 0 when is_list(opts) do {recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform) - h = rnn_state(x, units, :gru, opts[:name], "h", recurrent_initializer) + {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end) + h = rnn_state(x, units, :gru, opts[:name], "h", recurrent_initializer, seed) gru(x, {h}, units, opts) end @@ -2549,8 +2550,9 @@ defmodule Axon do def conv_lstm(%Axon{} = x, units, opts) when is_integer(units) and units > 0 and is_list(opts) do {recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform) - c = rnn_state(x, units, :conv_lstm, opts[:name], "c", recurrent_initializer) - h = rnn_state(x, units, :conv_lstm, opts[:name], "h", recurrent_initializer) + {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> 
:erlang.system_time() end) + c = rnn_state(x, units, :conv_lstm, opts[:name], "c", recurrent_initializer, seed) + h = rnn_state(x, units, :conv_lstm, opts[:name], "h", recurrent_initializer, seed) conv_lstm(x, {c, h}, units, opts) end @@ -2727,9 +2729,12 @@ defmodule Axon do {output_sequence, {new_c, new_h}} end - defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer) do + defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer, seed) do initializer = initializer || :glorot_uniform - key = Nx.Random.key(:erlang.system_time()) |> Nx.backend_copy(Nx.Defn.Expr) + key = Nx.Random.key(seed) |> Nx.backend_copy(Nx.BinaryBackend) + + key_state = + param("key", fn _ -> Nx.shape(key) end, type: {:u, 32}, initializer: fn _, _ -> key end) name = case parent_name do @@ -2742,7 +2747,7 @@ defmodule Axon do "#{parent_name}_#{state_name}_hidden_state" end - fun = fn inputs, opts -> + fun = fn inputs, key, _opts -> shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type) case initializer do @@ -2758,7 +2763,7 @@ defmodule Axon do fun.(shape, {:f, 32}) arity == 3 -> - fun.(shape, {:f, 32}, opts[:key]) + fun.(shape, {:f, 32}, key) true -> raise ArgumentError, "bad arity for initializer" @@ -2766,7 +2771,7 @@ defmodule Axon do end end - layer(fun, [x], name: name, op_name: :recurrent_state, key: key) + layer(fun, [x, key_state], name: name, op_name: :recurrent_state) end @doc """ @@ -2881,51 +2886,44 @@ defmodule Axon do the update process. """ @doc type: :model - def freeze(%Axon{} = model, fun_or_predicate \\ :all) do - parameters_per_layer = - reduce_nodes(model, [], fn %Axon.Node{parameters: params}, acc -> - layer_params = - Enum.reduce(params, [], fn param, inner_acc -> - [param | inner_acc] - end) + def freeze(model, fun_or_predicate \\ :all) do + freeze(model, fun_or_predicate, true) + end - [layer_params | acc] - end) + defp freeze(%Axon{output: id, nodes: nodes} = axon, fun_or_predicate, flag) do + {nodes, _} = traverse_nodes(id, nodes, [], MapSet.new()) - parameters_to_freeze = + nodes = case fun_or_predicate do :all -> - List.flatten(parameters_per_layer) + freeze_nodes(nodes, flag) [{:up, n}] -> - parameters_per_layer - |> Enum.reverse() - |> Enum.take(n) - |> List.flatten() + {pre, post} = Enum.split(nodes, n) + freeze_nodes(pre, flag) ++ post [{:down, n}] -> - parameters_per_layer - |> Enum.reverse() - |> Enum.drop(n) - |> List.flatten() + {pre, post} = Enum.split(nodes, -n) + pre ++ freeze_nodes(post, flag) fun -> - parameters_per_layer - |> List.flatten() - |> Enum.filter(fun) + Enum.map(nodes, fn %Axon.Node{parameters: params} = axon_node -> + %{ + axon_node + | parameters: + Enum.map(params, fn p -> + if fun.(p), do: %{p | frozen: flag}, else: p + end) + } + end) end - map_nodes(model, fn %Axon.Node{parameters: params} = axon_node -> - frozen_params = - Enum.map(params, fn %{id: param_id} = v -> - if Enum.any?(parameters_to_freeze, fn %{id: id} -> param_id == id end) do - %{v | frozen: true} - else - v - end - end) + %{axon | nodes: Map.new(nodes, fn %{id: id} = node -> {id, node} end)} + end - %{axon_node | parameters: frozen_params} + defp freeze_nodes(nodes, flag) do + Enum.map(nodes, fn %Axon.Node{parameters: params} = axon_node -> + %{axon_node | parameters: Enum.map(params, fn p -> %{p | frozen: flag} end)} end) end @@ -2960,52 +2958,8 @@ defmodule Axon do the update process. 
""" @doc type: :model - def unfreeze(%Axon{} = model, fun_or_predicate \\ :all) do - parameters_per_layer = - reduce_nodes(model, [], fn %Axon.Node{parameters: params}, acc -> - layer_params = - Enum.reduce(params, [], fn param, inner_acc -> - [param | inner_acc] - end) - - [layer_params | acc] - end) - - parameters_to_freeze = - case fun_or_predicate do - :all -> - List.flatten(parameters_per_layer) - - [{:up, n}] -> - parameters_per_layer - |> Enum.reverse() - |> Enum.take(n) - |> List.flatten() - - [{:down, n}] -> - parameters_per_layer - |> Enum.reverse() - |> Enum.drop(n) - |> List.flatten() - - fun -> - parameters_per_layer - |> List.flatten() - |> Enum.filter(fun) - end - - map_nodes(model, fn %Axon.Node{parameters: params} = axon_node -> - frozen_params = - Enum.map(params, fn %{id: param_id} = v -> - if Enum.any?(parameters_to_freeze, fn %{id: id} -> param_id == id end) do - %{v | frozen: false} - else - v - end - end) - - %{axon_node | parameters: frozen_params} - end) + def unfreeze(model, fun_or_predicate \\ :all) do + freeze(model, fun_or_predicate, false) end @doc """ @@ -3401,7 +3355,7 @@ defmodule Axon do ## Options - * `:mode` - one of `:inference` or `:training`. Forwarded to layers + * `:mode` - one of `:inference` or `:train`. Forwarded to layers to control differences in compilation at training or inference time. Defaults to `:inference` @@ -3480,7 +3434,7 @@ defmodule Axon do ## Options - * `:mode` - one of `:inference` or `:training`. Forwarded to layers + * `:mode` - one of `:inference` or `:train`. Forwarded to layers to control differences in compilation at training or inference time. Defaults to `:inference` @@ -3544,7 +3498,7 @@ defmodule Axon do ## Options - * `:mode` - one of `:inference` or `:training`. Forwarded to layers + * `:mode` - one of `:inference` or `:train`. Forwarded to layers to control differences in compilation at training or inference time. Defaults to `:inference` @@ -3636,7 +3590,7 @@ defmodule Axon do @doc type: :model def serialize(%Axon{output: id, nodes: nodes}, params, opts \\ []) do Logger.warning( - "Attempting to serialize an Axon model. Serialiation is discouraged" <> + "Attempting to serialize an Axon model. Serialization is discouraged" <> " and will be deprecated, then removed in future releases. You should" <> " keep your model definitions as code and serialize your parameters using" <> " `Nx.serialize/2`." @@ -3782,27 +3736,22 @@ defmodule Axon do # Names are generated lazily at inspect, initialization, and compile # time, so for name we return a function which takes `op` and `op_count` # and returns a unique name for the given model. 
- defp unique_identifiers(type, nil) do - id = System.unique_integer([:positive, :monotonic]) - - name = fn op, op_counts -> + defp name(type, nil) do + fn op, op_counts -> count = op_counts[op] || 0 Atom.to_string(type) <> "_#{count}" end - - {id, name} end - defp unique_identifiers(_type, name_fn) when is_function(name_fn, 2) do - id = System.unique_integer([:positive, :monotonic]) - {id, name_fn} + defp name(_type, name_fn) when is_function(name_fn, 2) do + name_fn end - defp unique_identifiers(_type, name) when is_binary(name) do - {System.unique_integer([:positive, :monotonic]), fn _, _ -> name end} + defp name(_type, name) when is_binary(name) do + fn _, _ -> name end end - defp unique_identifiers(_, name) do + defp name(_type, name) do raise ArgumentError, "expected layer name to be a binary, a function or nil, " <> "got: #{inspect(name)}" diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index c83428bf..470c18ee 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -34,17 +34,22 @@ defmodule Axon.Compiler do debug? = Keyword.get(opts, :debug, false) mode = Keyword.get(opts, :mode, :inference) key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(:erlang.system_time()) end) - key = Nx.backend_copy(key, Nx.Defn.Expr) + config = %{mode: mode, key: key, debug?: debug?} + # TODO: Key should not be part of model funs because it is not deterministic + # Once we remove it here, we should change Nx.backend_copy above to Nx.BinaryBackend {time, {root_id, {cache, _op_counts}}} = :timer.tc(fn -> - to_model_funs(id, nodes, {%{}, %{}}, mode, key) + to_model_funs(id, nodes, {%{}, %{}}, config) end) if debug? do Logger.debug("Axon finished graph traversal in #{us_to_ms(time)}ms") end + predict_cache = + Map.new(cache, fn {_, {int_id, %{predict: predict}}} -> {int_id, %{predict: predict}} end) + predict_fun = fn params, inputs -> {:current_stacktrace, [_process_info, _fn | stacktrace]} = Process.info(self(), :current_stacktrace) @@ -54,13 +59,27 @@ defmodule Axon.Compiler do case mode do :train -> {pred_expr, {state_expr, _}} = - cache[root_id][:predict].(params, inputs, %{}, cache, %{}, stacktrace) + predict_cache[root_id][:predict].( + params, + inputs, + %{}, + predict_cache, + %{}, + stacktrace + ) %{prediction: pred_expr, state: state_expr} :inference -> {pred_expr, _} = - cache[root_id][:predict].(params, inputs, %{}, cache, %{}, stacktrace) + predict_cache[root_id][:predict].( + params, + inputs, + %{}, + predict_cache, + %{}, + stacktrace + ) pred_expr end @@ -80,6 +99,9 @@ defmodule Axon.Compiler do result end + init_cache = Map.new(cache, fn {_, {int_id, funs}} -> {int_id, funs} end) + key = Nx.backend_copy(key, Nx.Defn.Expr) + init_fun = fn template, init_params -> {:current_stacktrace, [_process_info, _fn | stacktrace]} = Process.info(self(), :current_stacktrace) @@ -88,7 +110,8 @@ defmodule Axon.Compiler do :timer.tc(fn -> param_keys = get_keys(nodes, key) - {_, {params, _}} = cache[root_id][:init].(template, cache, %{}, stacktrace, param_keys) + {_, {params, _}} = + init_cache[root_id][:init].(template, init_cache, %{}, stacktrace, param_keys) params end) @@ -113,23 +136,23 @@ defmodule Axon.Compiler do op_counts = Map.update(op_counts, op, 1, &(&1 + 1)) keys = - Enum.reduce(params, keys, fn %Axon.Parameter{name: param_name, initializer: fun}, - keys -> - {:arity, arity} = Function.info(fun, :arity) + Enum.reduce(params, keys, fn + %Axon.Parameter{name: param_name, initializer: fun}, keys -> + {:arity, arity} = Function.info(fun, :arity) - cond do - arity == 2 -> - 
keys + cond do + arity == 2 -> + keys - arity == 3 -> - <> = - :erlang.md5(name <> "." <> param_name) + arity == 3 -> + <> = + :erlang.md5(name <> "." <> param_name) - [{{id, param_name}, data} | keys] + [{{id, param_name}, data} | keys] - true -> - raise ArgumentError, "bad initializer arity" - end + true -> + raise ArgumentError, "bad initializer arity" + end end) {keys, op_counts} @@ -182,13 +205,17 @@ defmodule Axon.Compiler do " output, use `Axon.container`" end - defp to_model_funs(id, nodes, {cache, op_counts}, mode, key) do + defp to_model_funs(id, nodes, {cache, op_counts}, config) do case cache do - %{^id => _} -> - {id, {cache, op_counts}} + %{^id => {int_id, _}} -> + {int_id, {cache, op_counts}} %{} -> - recur_model_funs(nodes[id], nodes, {cache, op_counts}, mode, key) + {id, model_funs, cache, op_counts} = + recur_model_funs(nodes[id], nodes, {cache, op_counts}, config) + + int_id = map_size(cache) + {int_id, {Map.put(cache, id, {int_id, model_funs}), op_counts}} end end @@ -226,16 +253,25 @@ defmodule Axon.Compiler do {parent_shape, {Map.merge(parent_params, params), result_cache}} end + # If the node is ignored for the current mode, we pass through and recur next + defp recur_model_funs( + %Axon.Node{mode: node_mode, parent: [parent | _]}, + nodes, + {cache, op_counts}, + config + ) + when node_mode != :both and node_mode != config.mode do + recur_model_funs(nodes[parent], nodes, {cache, op_counts}, config) + end + defp recur_model_funs( %Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: %{output: output}}, _nodes, {cache, op_counts}, - _, _ ) do op_counts = Map.update(op_counts, :constant, 1, fn x -> x + 1 end) - - tensor = Nx.backend_copy(tensor, Nx.Defn.Expr) + tensor = Nx.backend_copy(tensor, Nx.BinaryBackend) predict_fun = fn _params, _inputs, state, _cache, result_cache, _fn_stacktrace -> out = safe_as_type(tensor, output) @@ -247,8 +283,7 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - - {id, {Map.put(cache, id, model_funs), op_counts}} + {id, model_funs, cache, op_counts} end defp recur_model_funs( @@ -261,8 +296,7 @@ defmodule Axon.Compiler do }, _nodes, {cache, op_counts}, - mode, - _key + %{mode: mode} ) do name = name_fn.(:input, op_counts) op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end) @@ -287,18 +321,16 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - - {id, {Map.put(cache, id, model_funs), op_counts}} + {id, model_funs, cache, op_counts} end defp recur_model_funs( %Axon.Node{id: id, op: :optional, parent: [parent]}, nodes, {cache, op_counts}, - mode, - key + config ) do - {parent_id, {cache, op_counts}} = to_model_funs(parent, nodes, {cache, op_counts}, mode, key) + {parent_id, {cache, op_counts}} = to_model_funs(parent, nodes, {cache, op_counts}, config) predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace -> {out, {state, result_cache}} = @@ -319,19 +351,17 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - - {id, {Map.put(cache, id, model_funs), op_counts}} + {id, model_funs, cache, op_counts} end defp recur_model_funs( %Axon.Node{id: id, op: :container, parent: [parents]}, nodes, cache_and_counts, - mode, - key + config ) do {parent_ids, {cache, op_counts}} = - deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, mode, key)) + deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, config)) op_counts = Map.update(op_counts, :container, 1, fn x -> x + 
1 end) @@ -387,16 +417,14 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - - {id, {Map.put(cache, id, model_funs), op_counts}} + {id, model_funs, cache, op_counts} end defp recur_model_funs( %Axon.Node{id: id, op: :namespace, name: name_fn, parent: [parent]}, nodes, {cache, op_counts}, - mode, - key + config ) do name = name_fn.(:namespace, op_counts) # To ensure that a namespace always has the same layer names, @@ -409,7 +437,7 @@ defmodule Axon.Compiler do # we forward this name to the namespace, but everything after # it belongs to whatever namespace we're currently in {parent_id, {cache, namespace_op_counts}} = - to_model_funs(parent, nodes, {cache, namespace_op_counts}, mode, key) + to_model_funs(parent, nodes, {cache, namespace_op_counts}, config) # Update the global op_count of input layers, since they # are a global operation regardless of where they are @@ -465,9 +493,7 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - - # Then we return the cache, op_counts, and original namespace - {id, {Map.put(cache, id, model_funs), op_counts}} + {id, model_funs, cache, op_counts} end defp recur_model_funs( @@ -486,8 +512,7 @@ defmodule Axon.Compiler do }, nodes, cache_and_counts, - mode, - key + %{mode: mode, debug?: debug?} = config ) when (is_function(op) or is_atom(op)) and is_list(inputs) do # Traverse to accumulate cache and get parent_ids for @@ -498,7 +523,7 @@ defmodule Axon.Compiler do Enum.map_reduce( inputs, cache_and_counts, - &to_model_funs(&1, nodes, &2, mode, key) + &to_model_funs(&1, nodes, &2, config) ) # Names are computed lazily, so compute name from current @@ -506,6 +531,8 @@ defmodule Axon.Compiler do name = name_fn.(op_name, op_counts) op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end) + stacktrace = if debug?, do: stacktrace, else: [] + # Each model builds two functions: predict_fun and init_fun predict_fun = &layer_predict_fun( @@ -544,7 +571,7 @@ defmodule Axon.Compiler do ) model_funs = %{predict: predict_fun, init: init_fun} - {id, {Map.put(cache, id, model_funs), op_counts}} + {id, model_funs, cache, op_counts} end defp get_input(inputs, name, optional?) do @@ -675,7 +702,7 @@ defmodule Axon.Compiler do # Compute arguments to be forwarded and ensure `:mode` is included # for inference/training behavior dependent functions - args = Enum.reverse(tensor_inputs) ++ [Keyword.put(opts, :mode, mode)] + args = Enum.reverse(tensor_inputs, [Keyword.put(opts, :mode, mode)]) # For built-in layers we always just apply the equivalent function # in Axon.Layers. 
The implication of this is that every function which diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index b0177c5b..99bc8f91 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -121,7 +121,7 @@ defmodule Axon.Layers do dense_impl(input, kernel, bias, opts) end - defnp dense_impl(input, kernel, bias, _opts \\ []) do + defnp dense_impl(input, kernel, bias, _opts) do assert_min_rank!("Axon.Layers.dense", "input", input, 2) input @@ -182,7 +182,7 @@ defmodule Axon.Layers do bilinear_impl(input1, input2, kernel, bias, opts) end - defnp bilinear_impl(input1, input2, kernel, bias, _opts \\ []) do + defnp bilinear_impl(input1, input2, kernel, bias, _opts) do assert_min_rank!("Axon.Layers.bilinear", "input1", input1, 2) assert_min_rank!("Axon.Layers.bilinear", "input2", input2, 2) assert_equal_rank!("Axon.Layers.bilinear", "input1", input1, "input2", input2) @@ -339,7 +339,7 @@ defmodule Axon.Layers do conv_impl(input, kernel, bias, opts) end - defnp conv_impl(input, kernel, bias, opts \\ []) do + defnp conv_impl(input, kernel, bias, opts) do assert_min_rank!("Axon.Layers.conv", "input", input, 3) assert_equal_rank!("Axon.Layers.conv", "input", input, "kernel", kernel) @@ -478,7 +478,7 @@ defmodule Axon.Layers do conv_transpose_impl(input, kernel, bias, opts) end - defnp conv_transpose_impl(input, kernel, bias, opts \\ []) do + defnp conv_transpose_impl(input, kernel, bias, opts) do assert_min_rank!("Axon.Layers.conv_transpose", "input", input, 3) assert_equal_rank!("Axon.Layers.conv_transpose", "input", input, "kernel", kernel) @@ -583,7 +583,7 @@ defmodule Axon.Layers do depthwise_conv_impl(inputs, kernel, bias, opts) end - defnp depthwise_conv_impl(input, kernel, bias, opts \\ []) do + defnp depthwise_conv_impl(input, kernel, bias, opts) do assert_min_rank!("Axon.Layers.depthwise_conv", "input", input, 3) assert_equal_rank!("Axon.Layers.depthwise_conv", "input", input, "kernel", kernel) diff --git a/lib/axon/loss_scale.ex b/lib/axon/loss_scale.ex index a8838ab7..116d1292 100644 --- a/lib/axon/loss_scale.ex +++ b/lib/axon/loss_scale.ex @@ -35,7 +35,7 @@ defmodule Axon.LossScale do Implements static loss-scale. """ def static(loss_scale \\ @default_loss_scale) do - loss_scale = Nx.backend_copy(loss_scale, Nx.Defn.Expr) + loss_scale = Nx.backend_copy(loss_scale, Nx.BinaryBackend) {fn -> init_static(loss_scale) end, &scale_static/2, &unscale_static/2} end @@ -64,7 +64,7 @@ defmodule Axon.LossScale do Implements dynamic loss-scale. """ def dynamic(loss_scale \\ @default_loss_scale, opts \\ []) do - loss_scale = Nx.backend_copy(loss_scale, Nx.Defn.Expr) + loss_scale = Nx.backend_copy(loss_scale, Nx.BinaryBackend) { fn -> init_dynamic(loss_scale) end, diff --git a/lib/axon/losses.ex b/lib/axon/losses.ex index 0396922b..34cdcaf1 100644 --- a/lib/axon/losses.ex +++ b/lib/axon/losses.ex @@ -135,12 +135,19 @@ defmodule Axon.Losses do # both and perform this whole thing. If neither is set, we set this to # nil and then avoid the weighted avg later on. 
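# A short usage sketch of the class-weighted loss, assuming the existing public
# options of `Axon.Losses.binary_cross_entropy/3`; with the change below the
# `[neg, pos]` lookup tensors are created on `Nx.Defn.Expr`, so the weighting
# stays inside the defn expression.
y_true = Nx.tensor([[0], [1], [1], [0]])
y_pred = Nx.tensor([[0.1], [0.9], [0.4], [0.2]])

Axon.Losses.binary_cross_entropy(y_true, y_pred,
  positive_weight: 0.7,
  negative_weight: 0.3,
  reduction: :mean
)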
weights = - case {opts[:positive_weight], opts[:negative_weight]} do - {nil, nil} -> nil - {pos, nil} -> Nx.take(Nx.tensor([1.0, pos]), y_true) - {nil, neg} -> Nx.take(Nx.tensor([neg, 1.0]), y_true) - {pos, neg} -> Nx.take(Nx.tensor([neg, pos]), y_true) - end + transform({y_true, opts[:positive_weight], opts[:negative_weight]}, fn + {_, nil, nil} -> + nil + + {y_true, pos, nil} -> + Nx.take(Nx.tensor([1.0, pos], backend: Nx.Defn.Expr), y_true) + + {y_true, nil, neg} -> + Nx.take(Nx.tensor([neg, 1.0], backend: Nx.Defn.Expr), y_true) + + {y_true, pos, neg} -> + Nx.take(Nx.tensor([neg, pos], backend: Nx.Defn.Expr), y_true) + end) # Merge types before computing loss to prevent under/overflow. This # can especially happen when targets are encoded as u8 tensors. We diff --git a/lib/axon/node.ex b/lib/axon/node.ex index 5db569d3..e50e3238 100644 --- a/lib/axon/node.ex +++ b/lib/axon/node.ex @@ -4,6 +4,7 @@ defmodule Axon.Node do defstruct [ :id, :name, + :mode, :parent, :parameters, :args, diff --git a/lib/axon/parameter.ex b/lib/axon/parameter.ex index 4c1c7bde..50096d5f 100644 --- a/lib/axon/parameter.ex +++ b/lib/axon/parameter.ex @@ -1,4 +1,4 @@ defmodule Axon.Parameter do @moduledoc false - defstruct [:id, :name, :shape, :initializer, type: {:f, 32}, frozen: false] + defstruct [:name, :shape, :initializer, type: {:f, 32}, frozen: false] end diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index 1cc70d14..792b8825 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -114,7 +114,7 @@ defmodule CompilerTest do x2 = Axon.dense(input, 64) model = Axon.add(x1, x2) - {init_fn, _predict_fn} = Axon.build(model) + {init_fn, _predict_fn} = Axon.build(model, debug: true) %Axon.CompileError{} = exception = catch_error(init_fn.(Nx.template({1, 16}, :f32), %{})) message = Exception.message(exception) @@ -1025,7 +1025,7 @@ defmodule CompilerTest do input = Nx.random_uniform({1, 1, 32}) - assert {init_fn, _predict_fn} = Axon.build(model) + assert {init_fn, _predict_fn} = Axon.build(model, mode: :train) assert %{"dropout" => %{"key" => key}} = init_fn.(input, %{}) assert_equal(key, Nx.Random.key(0)) end @@ -1073,13 +1073,13 @@ defmodule CompilerTest do test "computes forward pass with default options" do for dropout <- @dropout_layers do - model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32})]) - input1 = Nx.random_uniform({1, 1, 32}, type: {:f, 32}) + model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 32, 32})]) + input1 = Nx.random_uniform({1, 32, 32}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) %{prediction: result1} = predict_fn.(init_fn.(input1, %{}), input1) - assert Nx.shape(result1) == {1, 1, 32} + assert Nx.shape(result1) == {1, 32, 32} assert Nx.type(result1) == {:f, 32} assert_not_equal(result1, input1) @@ -1107,15 +1107,15 @@ defmodule CompilerTest do test "computes forward pass with custom options" do for dropout <- @dropout_layers do - opts1 = [rate: 0.25] - model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32}), opts1]) - input1 = Nx.random_uniform({1, 1, 32}, type: {:f, 32}) + opts1 = [rate: 0.5] + model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 32, 128}), opts1]) + input1 = Nx.random_uniform({1, 32, 128}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) %{prediction: result} = predict_fn.(init_fn.(input1, %{}), input1) - assert Nx.shape(result) == {1, 1, 32} + assert Nx.shape(result) == {1, 32, 128} assert 
Nx.type(result) == {:f, 32} assert_not_equal(result, input1) end @@ -1129,8 +1129,8 @@ defmodule CompilerTest do input = Nx.random_uniform({1, 1, 32}) - assert {init_fn, predict_fn} = Axon.build(mp_model) - assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} + assert {init_fn, predict_fn} = Axon.build(mp_model, mode: :train) + assert Nx.type(predict_fn.(init_fn.(input, %{}), input).prediction) == {:bf, 16} end end @@ -5250,4 +5250,13 @@ defmodule CompilerTest do ) end end + + describe "determinism" do + test "builds the same model multiple times" do + builder = fn -> Axon.input("input", shape: {nil, 784}) |> Axon.dense(128) end + {_, predict_fn1} = Axon.Compiler.build(builder.(), []) + {_, predict_fn2} = Axon.Compiler.build(builder.(), []) + assert predict_fn1 == predict_fn2 + end + end end diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index ec0316cd..0c9f4c6a 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -61,7 +61,7 @@ defmodule Axon.LoopTest do test "trainer/3 returns a supervised training loop with custom loss" do model = Axon.input("input", shape: {nil, 1}) - custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.Defn.Expr) end + custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.BinaryBackend) end assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = Loop.trainer(model, custom_loss_fn, :adam)
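# A closing sketch of the intended behavior, assuming the compiler pass-through
# clause for `:mode` works as written above: dropout layers are compiled only
# when the model is built with `mode: :train`, and a train-mode predict
# function returns a `%{prediction: ..., state: ...}` map.
model = Axon.input("input", shape: {nil, 8}) |> Axon.dropout(rate: 0.5)

{init_fn, train_fn} = Axon.build(model, mode: :train)
{_init_fn, infer_fn} = Axon.build(model, mode: :inference)

input = Nx.iota({2, 8}, type: {:f, 32})
params = init_fn.(input, %{})

%{prediction: _dropped, state: _state} = train_fn.(params, input)

# At inference the :train-only dropout node is skipped, so the input passes
# through unchanged.
infer_fn.(params, input)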