Validation loops always recompile inside supervised training loops #416
Comments
Hi @nickgnd! Can you try passing …
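Presumably this refers to passing the compiler to the loop run itself, which is also what the benchmark later in the thread ends up doing; the exact option is missing above, so treat the following as an assumption:

```elixir
# Assumed reconstruction of the suggestion (the exact snippet was not captured):
# pass the compiler per run instead of relying on a global default.
# `loop` and `train_batches` refer to the training loop and batched data from
# the benchmark script further down in this thread.
Axon.Loop.run(loop, train_batches, %{}, epochs: 10, compiler: EXLA)
```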
Hi @nickgnd. I cannot reproduce this. But generally speaking, you want to set …
Btw, I could confirm we are getting cache hits for EXLA. :)
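For context, the two global knobs discussed throughout this thread are the default backend and the default `defn` compiler. A minimal sketch, assuming the EXLA-everywhere setup being suggested above:

```elixir
# Default backend for tensors created/manipulated outside defn
Nx.global_default_backend(EXLA.Backend)

# Default compiler for defn-compiled functions
# (this is the option toggled in the benchmarks below)
Nx.Defn.global_default_options(compiler: EXLA)
```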
Hey @josevalim sure! So I changed the call to … The log file is pretty big, more than 10000 lines, so I'm sharing just the first 100 lines for each here in the comment, but you can find the full logs in the same gist (https://gist.github.com/nickgnd/0d8c986b5887f4ccb7442b7145001d7f).
I quickly looked at them and I didn't spot any difference :/
@josevalim thanks for trying it out and for the clarification!
Interesting, it happens deterministically on my machine. I'm wondering what I'm doing wrong or what's wrong with my setup 🤔 Let me know if there is anything else that I can check.
Can you try with the backend config I pasted? It should be the fastest. :)
--
José Valim, https://dashbit.co/
Yup, I already tried it, and indeed it is faster now.
Do you mean that it should be the fastest overall? Or only when creating the tensors? The Livebook should probably not be considered a reliable benchmark, but here are the numbers: …
Maybe the dataset/model is too simple/small? Idk 🤔 Needless to say, in both cases I removed the function call …
Side note: I found it a bit confusing that `Nx.global_default_backend(EXLA.Backend)` returns `{Nx.BinaryBackend, []}`. Without looking at the implementation underneath, it seems that it was not properly updated. Cheers ✌️
It is common for stateful operations to return their current state on update. So a PR to document that is welcome. :)
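A minimal sketch of the behaviour described above, assuming a fresh runtime where no backend was previously configured:

```elixir
# The return value is the backend that was in effect *before* the update,
# hence {Nx.BinaryBackend, []} on a fresh runtime.
Nx.global_default_backend(EXLA.Backend)
#=> {Nx.BinaryBackend, []}

# Calling it again should now return the EXLA backend set above.
Nx.global_default_backend(EXLA.Backend)
#=> {EXLA.Backend, []}
```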
Make sure to restart the runtime on those changes, because application/pdict state can persist across evaluations. But in general I would expect that: …
It is unclear though why you are seeing much higher times for the global compiler. My guess is that we are compiling functions that we should not (for example, …).
Hey @josevalim, thanks again for getting back to me. I created a quick benchmark to try out your suggestions. Apparently, what slows down the training when the compiler is set globally is the validation step in the training loop, I mean this line of code: …

```elixir
Mix.install([
{:exla, "~> 0.4.1"},
{:nx, "~> 0.4.1"},
{:axon, "~> 0.3.1"},
{:benchee, "~> 1.1.0"},
])
#######################################
# Benchmark: Nx compiler vs Nx backend
#
# run with `elixir axon_bench.ex`
#
# There are 4 scenarios, run them separately:
#
# 1.
# Set the backend globally
# Nx.global_default_backend(EXLA.Backend)
## 1.1 Without validation data
# Name ips average deviation median 99th %
# training 11.45 87.37 ms ±3.50% 86.74 ms 96.43 ms
#
# Memory usage statistics:
#
# Name Memory usage
# training 11.91 MB
## 1.2 With validation data
# Name ips average deviation median 99th %
# training 5.15 194.00 ms ±6.35% 197.50 ms 205.55 ms
#
# Memory usage statistics:
#
# Name average deviation median 99th %
# training 15.89 MB ±0.00% 15.89 MB 15.89 MB
# 2.
# Set the compilation option globally
# Nx.Defn.global_default_options(compiler: EXLA)
## 2.1 Without validation data
# Name ips average deviation median 99th %
# training 14.71 67.98 ms ±3.83% 67.64 ms 76.03 ms
#
# Memory usage statistics:
#
# Name Memory usage
# training 10.47 MB
## 2.2 With validation data 🔥
# Name ips average deviation median 99th %
# training 0.72 1.38 s ±4.70% 1.37 s 1.45 s
#
# Memory usage statistics:
#
# Name average deviation median 99th %
# training 12.99 MB ±0.02% 12.99 MB 12.99 MB
#######################################
# Generate the data
key = Nx.Random.key(42)
{inputs, _new_key} = Nx.Random.normal(key, 0, 1, shape: {1000, 2}, type: :f16)
labels =
Enum.map(0..999, fn _ -> Enum.random([0, 1]) end)
|> Nx.tensor()
|> Nx.reshape({:auto, 1})
|> Nx.equal(Nx.tensor([0, 1]))
## Equally split the tensors in 2 parts: train and validation datasets
[x_train, x_validation] = Nx.to_batched(inputs, 500) |> Enum.to_list()
[y_train, y_validation] = Nx.to_batched(labels, 500) |> Enum.to_list()
## Divide the train dataset in batches
batch_size = 25
train_inputs = Nx.to_batched(x_train, batch_size)
train_labels = Nx.to_batched(y_train, batch_size)
train_batches = Stream.zip(train_inputs, train_labels)
validation_data = [{x_validation, y_validation}]
# Create and train the model
model =
Axon.input("data")
|> Axon.dense(100, activation: :sigmoid)
|> Axon.dense(2, activation: :softmax)
loop =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.rmsprop(0.001))
|> Axon.Loop.metric(:accuracy)
# |> Axon.Loop.validate(model, validation_data) # <----- uncomment this to add the validation step
Benchee.run(
%{
"training" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 10, compiler: EXLA) end
},
time: 4,
memory_time: 2
)
```

That said, I'm not asking you (or anyone else) to follow up on this, I just wanted to share my latest findings. Cheers ✌️
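For clarity, the "with validation data" scenarios above correspond to building the loop with the commented line enabled, i.e. roughly:

```elixir
# Scenario x.2: the same loop with the validation step enabled
# (the line that triggers the slowdown when the compiler is set globally)
loop_with_validation =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.rmsprop(0.001))
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, validation_data)
```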
@nickgnd Thanks for sharing this, it actually makes a lot of sense. In some of my recent work I noticed validation loops ALWAYS recompile, and I haven't been able to get to the source of the issue. I will leave this open until we can force cache hits in validation loops!
Hi 👋
First of all, thank you for all the great work you are putting into the Nx* libraries. I'm slowly learning ML by playing with Nx and Axon, and that's a real pleasure 🙂
I was playing with a simple model and I noticed that when the EXLA compiler is set globally with this command:
Nx.Defn.global_default_options(compiler: EXLA)
the training is slower compared to when it is not set. This surprises me a bit, but I might have overlooked something.
I shared the Livebook I used in a gist that you can use to reproduce the behaviour I described above.
https://gist.github.com/nickgnd/0d8c986b5887f4ccb7442b7145001d7f
For reference, the numbers from the two screenshots:

- With `Nx.Defn.global_default_options(compiler: EXLA)`, it takes 9.1 seconds to train the model (100 epochs).
- With `Nx.Defn.global_default_options([])`, it takes 2.7 seconds to train the model (100 epochs).

Moving the call `Nx.Defn.global_default_options(compiler: EXLA)` from the top of the Livebook to right before the training does not change the behaviour.
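A minimal sketch of that comparison, assuming a `loop` and `train_batches` built as in the benchmark script shared above (this is not the original Livebook cell, just a rough `:timer.tc/1` equivalent):

```elixir
# 1) With the EXLA compiler set globally
Nx.Defn.global_default_options(compiler: EXLA)
{micros, _state} = :timer.tc(fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 100) end)
IO.puts("global compiler: #{micros / 1_000_000}s")

# 2) After restarting the runtime: no global compiler option
Nx.Defn.global_default_options([])
{micros, _state} = :timer.tc(fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 100) end)
IO.puts("no global compiler: #{micros / 1_000_000}s")
```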
To conclude, it could well be that I didn't understand how/when to use `Nx.Defn.global_default_options` and/or that I'm doing something wrong, sorry in advance for that 🙏
Let me know if there is anything else that I can do to support you,
Thanks.
Keep rocking 🎸
My current setup:
macOS Monterey 12.6
erlang 25.1.2
elixir 1.14.2-otp-25
livebook 0.8.0
processor: 2.3 GHz Quad-Core Intel Core i7
memory: 32 GB
graphic card: Intel Iris Plus Graphics 1536 MB