
Validation loops always recompile inside supervised training loops #416

Closed
nickgnd opened this issue Dec 12, 2022 · 12 comments


@nickgnd
Contributor

nickgnd commented Dec 12, 2022

Hi 👋
First of all, thank you for all the great work you are putting into the Nx* libraries. I'm slowly learning ML by playing with Nx and Axon, and it's a real pleasure 💜


I was playing with a simple model and noticed that when the EXLA compiler is set globally with Nx.Defn.global_default_options(compiler: EXLA), training is slower than when it is not set.
This surprises me a bit, but I might have overlooked something.

I shared the Livebook I used in a gist that you can use to reproduce the behaviour I described above.

https://gist.github.com/nickgnd/0d8c986b5887f4ccb7442b7145001d7f

For reference:

  • With Nx.Defn.global_default_options(compiler: EXLA), it takes 9.1 seconds to train the model (100 epochs).

  • With Nx.Defn.global_default_options([]), it takes 2.7 seconds to train the model (100 epochs).

Moving the call Nx.Defn.global_default_options(compiler: EXLA) from the top of the Livebook to right before the training does not change the behaviour.

To conclude, it's possible that I didn't understand how/when to use Nx.Defn.global_default_options and/or that I'm doing something wrong; sorry in advance for that 🙈

Let me know if there is anything else that I can do to support you,
Thanks.

Keep rocking 🎸


My current setup:

  • macOS Monterey 12.6

  • erlang 25.1.2

  • elixir 1.14.2-otp-25

  • livebook 0.8.0

  • processor: 2.3 GHz Quad-Core Intel Core i7

  • memory: 32 GB

  • graphics card: Intel Iris Plus Graphics 1536 MB

@josevalim
Contributor

Hi @nickgnd! Can you try passing compiler: EXLA, debug: true in both cases and paste the log outputs. I am suspecting we are compiling things more frequently than we should. :)

@josevalim
Contributor

josevalim commented Dec 12, 2022

Hi @nickgnd. I cannot reproduce this.

But generally speaking, you want to set Nx.global_default_backend(EXLA.Backend) (or set the default backend in your config/config.exs) and pass the compiler on demand instead of globally. I will improve the docs.
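A minimal sketch of the setup suggested above, assuming Nx/EXLA 0.4.x (the versions used later in this thread): the backend is set once, and EXLA compilation is requested per function with Nx.Defn.jit/2 rather than through Nx.Defn.global_default_options. The softmax function is a hypothetical example chosen only for illustration.

```elixir
Mix.install([
  {:nx, "~> 0.4.1"},
  {:exla, "~> 0.4.1"}
])

# Allocate new tensors with EXLA by default.
# (Project-wide alternative in config/config.exs:
#  config :nx, default_backend: EXLA.Backend)
Nx.global_default_backend(EXLA.Backend)

# Hypothetical numerical function, used only for illustration.
softmax = fn t ->
  exp = Nx.exp(Nx.subtract(t, Nx.reduce_max(t)))
  Nx.divide(exp, Nx.sum(exp))
end

# Ask for EXLA compilation for this function only, instead of
# making `compiler: EXLA` the global default for every defn call.
jitted_softmax = Nx.Defn.jit(softmax, compiler: EXLA)

result = jitted_softmax.(Nx.tensor([1.0, 2.0, 3.0]))
IO.inspect(result)
```

With Axon, the on-demand equivalent is passing the compiler to the loop itself, as in the benchmark later in this thread: Axon.Loop.run(loop, train_batches, %{}, epochs: 10, compiler: EXLA).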

@josevalim
Contributor

Btw, I could confirm we are getting cache hits for EXLA. :)

@nickgnd
Contributor Author

nickgnd commented Dec 12, 2022

Hey @josevalim sure!

So I changed the call to Loop.run to: Axon.Loop.run(train_batches, %{}, epochs: 100, compiler: EXLA, debug: true)

The log files are pretty big (more than 10,000 lines), so here in the comment I'm sharing just the first 100 lines of each; you can find the full logs in the same gist (https://gist.github.com/nickgnd/0d8c986b5887f4ccb7442b7145001d7f).


  1. When setting the EXLA compiler globally with Nx.Defn.global_default_options(compiler: EXLA)
12:26:00.912 [debug] Axon.Loop started initializing loop state
12:26:00.913 [debug] EXLA defn evaluation #Function<135.40305314/2 in Nx.Defn.Compiler.fun/2> cache miss in 1.4ms
12:26:01.174 [debug] EXLA compilation #Function<135.40305314/2 in Nx.Defn.Compiler.fun/2> cache miss in 260.6ms
12:26:01.174 [debug] EXLA device 0 lock in 0.0ms
12:26:01.174 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.174 [debug] Axon.Loop finished initializing loop state in 262.5ms
12:26:01.175 [debug] Axon.Loop started running epoch 0
12:26:01.175 [debug] Axon.Loop started batch step execution
12:26:01.177 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 1.3ms
12:26:01.356 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 178.9ms
12:26:01.356 [debug] EXLA device 0 lock in 0.1ms
12:26:01.356 [debug] EXLA execution on device 0 in 0.4ms
12:26:01.356 [debug] Axon.Loop finished batch step execution in 181.1ms
12:26:01.356 [debug] Axon.Loop fired event :iteration_completed
12:26:01.357 [debug] Axon.Loop handled event :iteration_completed with status :continue
12:26:01.357 [debug] Axon.Loop started batch step execution
12:26:01.358 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 1.0ms
12:26:01.522 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 164.4ms
12:26:01.522 [debug] EXLA device 0 lock in 0.0ms
12:26:01.523 [debug] EXLA execution on device 0 in 0.3ms
12:26:01.523 [debug] Axon.Loop finished batch step execution in 165.9ms
12:26:01.523 [debug] Axon.Loop fired event :iteration_completed
12:26:01.523 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.523 [debug] Axon.Loop started batch step execution
12:26:01.523 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.523 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.523 [debug] EXLA device 0 lock in 0.1ms
12:26:01.524 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.524 [debug] Axon.Loop finished batch step execution in 0.6ms
12:26:01.524 [debug] Axon.Loop fired event :iteration_completed
12:26:01.524 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.524 [debug] Axon.Loop started batch step execution
12:26:01.524 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.524 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.524 [debug] EXLA device 0 lock in 0.0ms
12:26:01.524 [debug] EXLA execution on device 0 in 0.3ms
12:26:01.524 [debug] Axon.Loop finished batch step execution in 0.6ms
12:26:01.524 [debug] Axon.Loop fired event :iteration_completed
12:26:01.524 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.528 [debug] Axon.Loop started batch step execution
12:26:01.528 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.528 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.528 [debug] EXLA device 0 lock in 0.0ms
12:26:01.528 [debug] EXLA execution on device 0 in 0.3ms
12:26:01.528 [debug] Axon.Loop finished batch step execution in 0.6ms
12:26:01.528 [debug] Axon.Loop fired event :iteration_completed
12:26:01.528 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.528 [debug] Axon.Loop started batch step execution
12:26:01.529 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.529 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.529 [debug] EXLA device 0 lock in 0.1ms
12:26:01.529 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.529 [debug] Axon.Loop finished batch step execution in 0.6ms
12:26:01.529 [debug] Axon.Loop fired event :iteration_completed
12:26:01.529 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.529 [debug] Axon.Loop started batch step execution
12:26:01.529 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.529 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.529 [debug] EXLA device 0 lock in 0.0ms
12:26:01.530 [debug] EXLA execution on device 0 in 0.3ms
12:26:01.530 [debug] Axon.Loop finished batch step execution in 0.6ms
12:26:01.530 [debug] Axon.Loop fired event :iteration_completed
12:26:01.532 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.532 [debug] Axon.Loop started batch step execution
12:26:01.532 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.532 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.532 [debug] EXLA device 0 lock in 0.0ms
12:26:01.533 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.533 [debug] Axon.Loop finished batch step execution in 0.8ms
12:26:01.533 [debug] Axon.Loop fired event :iteration_completed
12:26:01.533 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.533 [debug] Axon.Loop started batch step execution
12:26:01.533 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.533 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.533 [debug] EXLA device 0 lock in 0.0ms
12:26:01.533 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.533 [debug] Axon.Loop finished batch step execution in 0.4ms
12:26:01.533 [debug] Axon.Loop fired event :iteration_completed
12:26:01.533 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.533 [debug] Axon.Loop started batch step execution
12:26:01.534 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.534 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.534 [debug] EXLA device 0 lock in 0.0ms
12:26:01.534 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.534 [debug] Axon.Loop finished batch step execution in 0.4ms
12:26:01.534 [debug] Axon.Loop fired event :iteration_completed
12:26:01.534 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.535 [debug] Axon.Loop started batch step execution
12:26:01.536 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.536 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.536 [debug] EXLA device 0 lock in 0.0ms
12:26:01.536 [debug] EXLA execution on device 0 in 0.2ms
12:26:01.536 [debug] Axon.Loop finished batch step execution in 0.7ms
12:26:01.536 [debug] Axon.Loop fired event :iteration_completed
12:26:01.536 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:26:01.536 [debug] Axon.Loop started batch step execution
12:26:01.536 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:26:01.536 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:26:01.536 [debug] EXLA device 0 lock in 0.0ms
12:26:01.537 [debug] EXLA execution on device 0 in 0.2ms
  2. When the compiler is unset globally with Nx.Defn.global_default_options([])
12:27:55.731 [debug] Axon.Loop started initializing loop state
12:27:55.733 [debug] EXLA defn evaluation #Function<135.40305314/2 in Nx.Defn.Compiler.fun/2> cache miss in 2.1ms
12:27:55.992 [debug] EXLA compilation #Function<135.40305314/2 in Nx.Defn.Compiler.fun/2> cache miss in 258.2ms
12:27:55.992 [debug] EXLA device 0 lock in 0.0ms
12:27:55.992 [debug] EXLA execution on device 0 in 0.2ms
12:27:55.992 [debug] Axon.Loop finished initializing loop state in 260.8ms
12:27:55.992 [debug] Axon.Loop started running epoch 0
12:27:55.993 [debug] Axon.Loop started batch step execution
12:27:55.993 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 0.8ms
12:27:56.170 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 176.5ms
12:27:56.170 [debug] EXLA device 0 lock in 0.0ms
12:27:56.171 [debug] EXLA execution on device 0 in 0.4ms
12:27:56.171 [debug] Axon.Loop finished batch step execution in 178.1ms
12:27:56.171 [debug] Axon.Loop fired event :iteration_completed
12:27:56.171 [debug] Axon.Loop handled event :iteration_completed with status :continue
12:27:56.171 [debug] Axon.Loop started batch step execution
12:27:56.173 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 1.4ms
12:27:56.345 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache miss in 172.1ms
12:27:56.345 [debug] EXLA device 0 lock in 0.0ms
12:27:56.345 [debug] EXLA execution on device 0 in 0.3ms
12:27:56.345 [debug] Axon.Loop finished batch step execution in 174.1ms
12:27:56.346 [debug] Axon.Loop fired event :iteration_completed
12:27:56.346 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.346 [debug] Axon.Loop started batch step execution
12:27:56.346 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.346 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.346 [debug] EXLA device 0 lock in 0.0ms
12:27:56.346 [debug] EXLA execution on device 0 in 0.3ms
12:27:56.346 [debug] Axon.Loop finished batch step execution in 0.6ms
12:27:56.346 [debug] Axon.Loop fired event :iteration_completed
12:27:56.346 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.346 [debug] Axon.Loop started batch step execution
12:27:56.346 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.347 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.347 [debug] EXLA device 0 lock in 0.0ms
12:27:56.347 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.347 [debug] Axon.Loop finished batch step execution in 0.4ms
12:27:56.347 [debug] Axon.Loop fired event :iteration_completed
12:27:56.350 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.350 [debug] Axon.Loop started batch step execution
12:27:56.350 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.350 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.350 [debug] EXLA device 0 lock in 0.0ms
12:27:56.350 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.350 [debug] Axon.Loop finished batch step execution in 0.5ms
12:27:56.350 [debug] Axon.Loop fired event :iteration_completed
12:27:56.350 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.350 [debug] Axon.Loop started batch step execution
12:27:56.351 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.351 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.351 [debug] EXLA device 0 lock in 0.0ms
12:27:56.351 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.351 [debug] Axon.Loop finished batch step execution in 0.5ms
12:27:56.351 [debug] Axon.Loop fired event :iteration_completed
12:27:56.351 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.351 [debug] Axon.Loop started batch step execution
12:27:56.351 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.351 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.351 [debug] EXLA device 0 lock in 0.0ms
12:27:56.352 [debug] EXLA execution on device 0 in 0.5ms
12:27:56.352 [debug] Axon.Loop finished batch step execution in 0.7ms
12:27:56.353 [debug] Axon.Loop fired event :iteration_completed
12:27:56.353 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.353 [debug] Axon.Loop started batch step execution
12:27:56.354 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.354 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.354 [debug] EXLA device 0 lock in 0.0ms
12:27:56.354 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.354 [debug] Axon.Loop finished batch step execution in 0.5ms
12:27:56.354 [debug] Axon.Loop fired event :iteration_completed
12:27:56.354 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.354 [debug] Axon.Loop started batch step execution
12:27:56.354 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.354 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.354 [debug] EXLA device 0 lock in 0.0ms
12:27:56.355 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.355 [debug] Axon.Loop finished batch step execution in 0.4ms
12:27:56.355 [debug] Axon.Loop fired event :iteration_completed
12:27:56.355 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.355 [debug] Axon.Loop started batch step execution
12:27:56.355 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.355 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.355 [debug] EXLA device 0 lock in 0.0ms
12:27:56.355 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.355 [debug] Axon.Loop finished batch step execution in 0.4ms
12:27:56.355 [debug] Axon.Loop fired event :iteration_completed
12:27:56.355 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.355 [debug] Axon.Loop started batch step execution
12:27:56.355 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.355 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.355 [debug] EXLA device 0 lock in 0.0ms
12:27:56.358 [debug] EXLA execution on device 0 in 0.2ms
12:27:56.358 [debug] Axon.Loop finished batch step execution in 2.4ms
12:27:56.358 [debug] Axon.Loop fired event :iteration_completed
12:27:56.358 [debug] Axon.Loop no handlers fired for event :iteration_completed
12:27:56.358 [debug] Axon.Loop started batch step execution
12:27:56.358 [debug] EXLA defn evaluation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.1ms
12:27:56.358 [debug] EXLA compilation #Function<2.47039548/4 in Axon.Loop.build_batch_fn/2> cache hit in 0.0ms
12:27:56.359 [debug] EXLA device 0 lock in 0.0ms
12:27:56.359 [debug] EXLA execution on device 0 in 0.2ms

I took a quick look at them and didn't spot any difference :/

@nickgnd
Contributor Author

nickgnd commented Dec 12, 2022

@josevalim thanks for trying it out and for the clarification!

> Hi @nickgnd. I cannot reproduce this.

Interesting, it happens deterministically on my machine. I'm wondering what I'm doing wrong or what's wrong with my setup 🤔

Let me know if there is anything else I should check.
Feel free to close the issue :)

@josevalim
Contributor

josevalim commented Dec 12, 2022 via email

@nickgnd
Contributor Author

nickgnd commented Dec 12, 2022

Yup, I already tried it, and it is indeed faster now.

> it should be the fastest. :)

Do you mean that it should be the fastest overall? Or only when creating the tensors?

The Livebook should probably not be considered a reliable benchmark, but here are the numbers:

  • without setting backend via Nx.global_default_backend(EXLA.Backend) the trainer takes 2.9s
  • with setting the backend via Nx.global_default_backend(EXLA.Backend) the trainer takes 3.1s

Maybe the dataset/model is too simple/small? Idk 🤔

☝️ Needless to say, in both cases I removed the call to Nx.Defn.global_default_options(compiler: EXLA), which I had been misusing.


Side note: I found it a bit confusing that Nx.global_default_backend/1 returns the previous backend rather than the one that has just been set.

Nx.global_default_backend(EXLA.Backend)
{Nx.BinaryBackend, []}

Without looking at the underlying implementation, it looks as if the backend was not properly updated.
I'd happily open a PR to update the implementation (I can give it a try), or at least the docs.


Cheers ✌️

@josevalim
Contributor

It is common for stateful operations to return their current state on update. So a PR to document that is welcome. :)
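Because Nx.global_default_backend/1 hands back the previous setting (the {Nx.BinaryBackend, []} output shown earlier in this thread), the return value can be kept and used to restore the old backend afterwards. A small sketch, assuming Nx/EXLA 0.4.x and a fresh runtime in which no default backend has been configured yet:

```elixir
Mix.install([
  {:nx, "~> 0.4.1"},
  {:exla, "~> 0.4.1"}
])

# The return value is the backend that was active *before* this call,
# e.g. {Nx.BinaryBackend, []} in a fresh runtime.
previous = Nx.global_default_backend(EXLA.Backend)

# Tensors created from here on are allocated by EXLA.Backend.
t = Nx.iota({3})
IO.inspect(t)

# Restore the earlier global setting.
Nx.global_default_backend(previous)
```

As josevalim notes in the next comment, this is application-level state, so in Livebook it persists across cell evaluations until the runtime is restarted.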

@josevalim
Contributor

Make sure to restart the runtime on those changes, because application/pdict state can persist across evaluations.

But in general I would expect:

  1. No backend or compiler configuration to be the slowest
  2. Backend only config to be faster than 1
  3. Compiler only config to be faster than 1

It is unclear though why you are seeing much higher times for global compiler. My guess is that we are compiling functions that we should not (for example, zeros_like).

@nickgnd
Contributor Author

nickgnd commented Dec 13, 2022

Hey @josevalim, thanks again for getting back to me.

I created a quick benchmark to try out your suggestions.

Apparently, what slows down training when the compiler is set globally is the validation step, i.e. this line of code: |> Axon.Loop.validate(model, validation_data).

Mix.install([
  {:exla, "~> 0.4.1"},
  {:nx, "~> 0.4.1"},
  {:axon, "~> 0.3.1"},
  {:benchee, "~> 1.1.0"},
])

#######################################

# Benchmark: Nx compiler vs Nx backend
#
# run with `elixir axon_bench.ex`
#
# There are 4 scenarios, run them separately:
#
# 1.
# Set the backend globally
# Nx.global_default_backend(EXLA.Backend)

## 1.1 Without validation data

# Name               ips        average  deviation         median         99th %
# training         11.45       87.37 ms     ±3.50%       86.74 ms       96.43 ms
#
# Memory usage statistics:
#
# Name        Memory usage
# training        11.91 MB


## 1.2 With validation data

# Name               ips        average  deviation         median         99th %
# training          5.15      194.00 ms     ±6.35%      197.50 ms      205.55 ms
#
# Memory usage statistics:
#
# Name             average  deviation         median         99th %
# training        15.89 MB     Β±0.00%       15.89 MB       15.89 MB

# 2.
# Set the compilation option globally
# Nx.Defn.global_default_options(compiler: EXLA)

## 2.1 Without validation data

# Name               ips        average  deviation         median         99th %
# training         14.71       67.98 ms     ±3.83%       67.64 ms       76.03 ms
#
# Memory usage statistics:
#
# Name        Memory usage
# training        10.47 MB

## 2.2 With validation data 💥

# Name               ips        average  deviation         median         99th %
# training          0.72         1.38 s     ±4.70%         1.37 s         1.45 s
#
# Memory usage statistics:
#
# Name             average  deviation         median         99th %
# training        12.99 MB     Β±0.02%       12.99 MB       12.99 MB

#######################################



# Generate the data

key = Nx.Random.key(42)

{inputs, _new_key} = Nx.Random.normal(key, 0, 1, shape: {1000, 2}, type: :f16)

labels =
  Enum.map(0..999, fn _ -> Enum.random([0, 1]) end)
  |> Nx.tensor()
  |> Nx.reshape({:auto, 1})
  |> Nx.equal(Nx.tensor([0, 1]))

## Equally split the tensors in 2 parts: train and validation datasets

[x_train, x_validation] = Nx.to_batched(inputs, 500) |> Enum.to_list()
[y_train, y_validation] = Nx.to_batched(labels, 500) |> Enum.to_list()

## Divide the train dataset in batches

batch_size = 25

train_inputs = Nx.to_batched(x_train, batch_size)
train_labels = Nx.to_batched(y_train, batch_size)
train_batches = Stream.zip(train_inputs, train_labels)

validation_data = [{x_validation, y_validation}]

# Create and train the model

model =
  Axon.input("data")
  |> Axon.dense(100, activation: :sigmoid)
  |> Axon.dense(2, activation: :softmax)

loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.rmsprop(0.001))
  |> Axon.Loop.metric(:accuracy)
  # |> Axon.Loop.validate(model, validation_data) # <----- uncomment this to add the validation step

Benchee.run(
  %{
    "training" => fn -> Axon.Loop.run(loop, train_batches, %{}, epochs: 10, compiler: EXLA) end
  },
  time: 4,
  memory_time: 2
)

That said, I'm not asking you (or anyone else) to follow up on this; I just wanted to share my latest findings.
Again, feel free to close the issue, and thanks for the great work you are doing here 🙇

Cheers ✌️

@seanmor5
Contributor

@nickgnd Thanks for sharing this, it actually makes a lot of sense. In some of my recent work I noticed validation loops ALWAYS recompile, and I haven't been able to get to the source of the issue. I will leave this open until we can force cache hits in validation loops!

@seanmor5 seanmor5 changed the title Training slower when setting EXLA as compiler globally Validation loops always recompile inside supervised training loops Dec 13, 2022
@seanmor5
Contributor

@nickgnd I believe this also was solved in #427 :)

3 participants