🔪 Remaining Sharp Bit TODOs 🔪 #9952

Open
levskaya opened this issue Mar 18, 2022 · 11 comments
Labels
enhancement · NVIDIA GPU · P2 (eventual)

Comments

@levskaya (Collaborator)

The Sharp Bits doc could use some sprucing up to cover common problems we've encountered in user code since it was first written.

Top of the list is documenting matmul / conv op precision issues:

  • bf16 multiplication defaults! On TPU, matmuls use bfloat16 inputs by default, which is bad for simulation / classic numerics.
  • the context manager for precision (see the sketch below)

We should add some other ideas here.
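As a starting point, here's a minimal sketch of what that section could show, assuming the current jax.default_matmul_precision context manager and the per-op precision= argument:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024), dtype=jnp.float32)

# On TPU, float32 matmuls are computed with bfloat16 inputs by default.
# The context manager forces full float32 for everything inside it:
with jax.default_matmul_precision("float32"):
    y = x @ x

# ...or request high precision per-op via the precision argument:
z = jnp.dot(x, x, precision=jax.lax.Precision.HIGHEST)
```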

levskaya added the enhancement label on Mar 18, 2022
@mattjj (Collaborator) commented Mar 18, 2022

Just adding context: the context manager for precision is defined here, and there are some words about it in #6143.

@levskaya (Collaborator, Author) commented Mar 18, 2022

Other sharp bits:

  • OOB accesses don't raise errors by default; they silently clip or drop! This is already in there, actually, but we should extend it a bit and mention the mode argument for the .at syntax (and add a link to checkify); see the sketch after this list.
  • We can't cover it in any detail, and it's not a JAX issue per se, but it's probably worth mentioning the general dangers of half-precision types: e.g. the ease of float16 overflow/underflow and the danger of accumulating into bf16.
  • accidental recompilation issues:
    • hashability of arguments / jit caching behavior
    • log-compiles feature for catching accidental recompiles.
    • worth a line: the danger of weak_type=True for triggering recompilation
  • Perhaps too esoteric / TPU-centric: we should mention the RngBitGenerator system for performance, maybe xla_tpu_spmd_rng_bit_generator_unsafe=true for OSS users of SPMD + RNG.
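
On the first bullet, a minimal sketch of the clipping behavior, the mode argument, and checkify (the checkify usage reflects its current experimental API and may change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

x = jnp.arange(4.0)

# Out-of-bounds reads don't error: the index is clamped to the last element.
print(x[10])  # 3.0, silently clipped

# The mode argument to .at[...] makes the behavior explicit:
print(x.at[10].get(mode="fill", fill_value=jnp.nan))  # nan

# checkify can turn OOB accesses into inspectable errors:
def f(x, i):
    return x[i]

checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(x, 10)
err.throw()  # raises, reporting the out-of-bounds access
```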

@mattjj (Collaborator) commented Mar 18, 2022

Fantastic!

On that first bullet, we could also mention checkify. cc @LenaMartens

@LenaMartens (Contributor)

Nice! +1 on the danger of weak_type=True for triggering recompilation; we've had people ask for more documentation on that. I might try to add it.
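
A minimal sketch of what that could look like, using the jax_log_compiles option to surface the cache misses:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_log_compiles", True)  # log every XLA compilation

@jax.jit
def f(x):
    return x + 1

f(jnp.float32(1.0))  # compiles: strongly-typed float32 scalar
f(jnp.float32(2.0))  # cache hit: same shape / dtype / weak_type
f(1.0)               # compiles again! A Python float traces as float32
                     # with weak_type=True, which is a different cache key.
```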

@levskaya (Collaborator, Author) commented Apr 1, 2022

Also: making sure that the asynchronously dispatched JAX calls in a training loop aren't interleaved with blocking host-side calls that kill dispatch-pipelining efficiency (e.g. a trivial host-side metrics fn or similar). This is one of the most common performance mistakes I see (maybe it belongs in a separate performance-gotchas doc... not sure). A toy sketch is below.
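
Here the "training step" and array sizes are made up purely for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(w, x):
    loss = jnp.sum((w * x) ** 2)  # stand-in for real device work
    return w - 0.01 * x, loss

w = jnp.ones(1024)
batches = [jnp.full((1024,), float(i)) for i in range(100)]

# Anti-pattern: float(loss) blocks the host every step, so step N+1
# can't be dispatched until step N has finished executing on-device.
for x in batches:
    w, loss = train_step(w, x)
    metrics = float(loss)  # host sync: kills dispatch pipelining

# Better: keep values on-device and sync rarely (or once at the end).
losses = []
for x in batches:
    w, loss = train_step(w, x)
    losses.append(loss)    # no sync; dispatch runs ahead of execution
w.block_until_ready()      # single sync at the end
print(float(losses[-1]))
```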

@jakevdp (Collaborator) commented Apr 1, 2022

I like the idea of having a new dedicated doc for performance tips and pitfalls.

@jakevdp (Collaborator) commented Apr 1, 2022

Regarding reworking the Sharp Bits doc, I recently added a section on miscellaneous divergences between numpy and JAX. It might be nice to rearrange things so all the differences between numpy and JAX are listed briefly under a single heading, perhaps with links to deeper discussion later in the doc.

@nalzok (Contributor) commented Jul 18, 2022

Regarding the "jit caching behavior": is there any chance you could cache the compiled result to the file system so that it persists across runs? In my development cycle, I typically change some hyperparameters and re-run the experiment. It's a little frustrating to wait for JIT compilation each time, even when I've compiled the exact same code multiple times before.

I am under the impression that this wouldn't be too hard to implement, since we already have a hashing/caching mechanism; all it takes is writing the emitted XLA program to disk. Should I open a new issue for this?

@jakevdp (Collaborator) commented Jul 18, 2022

@nalzok - there is currently an implementation of this, but only for TPU. See https://github.com/google/jax/tree/main/jax/experimental/compilation_cache for details, and #2490, where this kind of request is tracked.
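
At the time of writing, the experimental cache is enabled roughly like this (a sketch; the directory path is arbitrary, and the API may change):

```python
from jax.experimental.compilation_cache import compilation_cache as cc

# Persist compiled executables across processes (TPU-only at the moment):
cc.initialize_cache("/tmp/jax_compilation_cache")
```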

sudhakarsingh27 added the NVIDIA GPU and P2 (eventual) labels on Aug 10, 2022
@JeppeKlitgaard (Contributor)

I have a fairly RNG-generation-heavy workload that I am running on Cloud TPU, and I was googling around trying to understand the xla_tpu_spmd_rng_bit_generator_unsafe flag, but only found this thread and a brief mention in the JAX documentation. The quality of randomness is not critical for me. Am I right in assuming this flag improves performance, but at the cost of using a less well-understood algorithm underneath?

@levskaya (Collaborator, Author) commented Apr 16, 2023

@JeppeKlitgaard - yeah, it uses an ad hoc method of splitting keys that we don't have theoretical justification for (and in fact we don't really have well-established statistical tests for split-chain decorrelation when it comes to splittable PRNG systems). That said, it compiles and runs fast, and it's almost certainly good enough for e.g. dropout masks in the context of SGD training of NNs (we've used it for that with no observed ill effects for some time). I'd be a bit more careful if I were doing classic MCMC or something. A sketch of how to opt in is below.
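
For concreteness, opting in looks roughly like this (a sketch: using LIBTPU_INIT_ARGS to pass the XLA flag on Cloud TPU is an assumption, and the env var must be set before JAX initializes the TPU backend):

```python
import os

# XLA TPU flag for faster (but less well-justified) RNG under SPMD.
# Assumption: Cloud TPU picks up XLA flags via LIBTPU_INIT_ARGS.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "")
    + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
)

import jax

# Use the RngBitGenerator-backed PRNG implementation:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

key = jax.random.PRNGKey(0)      # now an RBG-style key under the hood
k1, k2 = jax.random.split(key)   # fast, ad hoc splitting
x = jax.random.normal(k1, (8,))
```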
