Provide deterministic builds #427
Conversation
josevalim commented Dec 13, 2022 (edited)
- Parameter IDs were removed
- Dropouts are completely removed from the network via a new :mode option (see the sketch after this list)
- Freezing traverses the nodes directly without relying on IDs
- Removed almost all usage of backend_copy(Nx.Defn.Expr)
- We use integers as cache keys after the cache is built
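To illustrate the new `:mode` option from the list above, here is a minimal sketch (layer sizes, shapes, and the dropout rate are arbitrary example values, not taken from this PR) of building the same model for training and for inference:

```elixir
# Hedged sketch: a small model with dropout, built once per mode.
model =
  Axon.input("features", shape: {nil, 16})
  |> Axon.dense(32, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(1)

# In :inference mode the dropout layer is removed from the compiled
# graph entirely, so repeated builds stay deterministic.
{init_fn, predict_fn} = Axon.build(model, mode: :inference)

# In :train mode the dropout (and its PRNG key) participates.
{train_init_fn, train_predict_fn} = Axon.build(model, mode: :train)
```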
lib/axon/compiler.ex
Outdated
)

# Names are computed lazily, so compute name from current
# op and aggregate op_counts.
name = name_fn.(op_name, op_counts)
op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)

# TODO: Hack for dropout with key, fix with a better implementation
Perhaps it should not be based on the name but I would say it should be based on the key. Thoughts?
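As a side note on the naming scheme in the hunk above, here is a toy illustration (not the compiler's actual code path; the op name and counter values are made up) of deriving a layer name from the aggregate op counts and then bumping the counter:

```elixir
# Toy example: compute a name from the current op and op_counts,
# then update the count so the next layer of the same op gets a
# distinct suffix.
op_counts = %{}
name_fn = fn op, counts -> "#{op}_#{Map.get(counts, op, 0)}" end

name = name_fn.("dense", op_counts)
# => "dense_0"

op_counts = Map.update(op_counts, "dense", 1, fn x -> x + 1 end)
# => %{"dense" => 1}
```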
# Compute arguments to be forwarded and ensure `:mode` is included
# for inference/training behavior dependent functions
- args = Enum.reverse(tensor_inputs) ++ [Keyword.put(opts, :mode, mode)]
+ args = Enum.reverse(tensor_inputs, [Keyword.put(opts, :mode, mode)])
`Enum.reverse(list, tail)` is an efficient version of `Enum.reverse(list) ++ tail`.
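For example (values are arbitrary):

```elixir
Enum.reverse([1, 2, 3], [4, 5])
# => [3, 2, 1, 4, 5]

# Same result, but builds an intermediate reversed list first:
Enum.reverse([1, 2, 3]) ++ [4, 5]
# => [3, 2, 1, 4, 5]
```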
lib/axon/layers.ex
Outdated
-   {_, out, :train} ->
-     out
- end)
+ Nx.select(mask, input / keep_prob, Nx.tensor(0, type: Nx.type(input)))
I removed the mode check from dropout because it is no longer relevant.
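For reference, a minimal self-contained sketch of the inverted-dropout computation that the `Nx.select/3` line above performs; the function name, the uniform noise mask, and the default rate are assumptions for illustration, not Axon's actual `Axon.Layers.dropout` implementation:

```elixir
import Nx.Defn

# Sketch only: kept elements are rescaled by 1 / keep_prob so the
# expected activation magnitude matches inference.
defn dropout_sketch(input, key, opts \\ []) do
  opts = keyword!(opts, rate: 0.5)
  keep_prob = 1.0 - opts[:rate]

  {noise, _new_key} = Nx.Random.uniform(key, shape: Nx.shape(input))
  mask = Nx.less(noise, keep_prob)

  Nx.select(mask, input / keep_prob, Nx.tensor(0, type: Nx.type(input)))
end
```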
@@ -114,7 +114,7 @@ defmodule CompilerTest do
  x2 = Axon.dense(input, 64)
  model = Axon.add(x1, x2)

- {init_fn, _predict_fn} = Axon.build(model)
+ {init_fn, _predict_fn} = Axon.build(model, debug: true)
We only get fn stacktraces with `debug: true` now.
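A hedged usage example (the model and input shape are arbitrary) showing where the option goes:

```elixir
model = Axon.input("input", shape: {nil, 32}) |> Axon.dense(64)

# With debug: true the build keeps the fn stacktraces mentioned
# above; without it they are omitted.
{init_fn, predict_fn} = Axon.build(model, debug: true)
params = init_fn.(Nx.template({1, 32}, :f32), %{})
```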
lib/axon.ex
Outdated
@@ -2722,7 +2721,8 @@ defmodule Axon do

  defp rnn_state(x, units, rnn_type, parent_name, state_name, initializer) do
    initializer = initializer || :glorot_uniform
    key = Nx.Random.key(:erlang.system_time()) |> Nx.backend_copy(Nx.Defn.Expr)
    # TODO: This key should be managed by the compiler
Similar to dropout.
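A small sketch of what compiler-managed, deterministic keys could look like; `Nx.Random.split/2` and the seed 42 are illustrative choices here, not this PR's implementation:

```elixir
# Derive independent per-layer keys from one root key, so that every
# build with the same seed reproduces the same dropout masks and
# initial RNN state.
root_key = Nx.Random.key(42)
keys = Nx.Random.split(root_key, parts: 2)

dropout_key = keys[0]
rnn_state_key = keys[1]

# e.g. a deterministic initial RNN state drawn from the split key
{initial_state, _key} = Nx.Random.normal(rnn_state_key, shape: {1, 64})
```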
This is good to go!
66c75bf to 1da7d56 (Compare)