Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.3.0: Performance Improvement #2

Merged
merged 12 commits into from
Jun 12, 2023
Merged

0.3.0: Performance Improvement #2

merged 12 commits into from
Jun 12, 2023

Conversation

JoeyTeng
Copy link
Owner

This is mostly a performance improvement version.

  1. Added an extra stage in pipeline between vertex shading and interpolation, to reduce the number of primitives involved in downstream computation, using a customisable shader method primitive_chooser. Default implementation will use the interpolated z as the depth and pick the closest valid triangle, reducing the number of involving primitives from the total number of primitives to 1 for each fragment. In a benchmark, it leads to 10x speedup. See Colab here.
  2. Eliminates almost all lax.cond to avoid extra computation and select_n HLO op as vmap + cond will convert cond to select, leading to an unconditional execution of all branches. See vmap of cond's predicate results in select, leading to unexpected compute/memory use jax-ml/jax#8409
  3. Expose loop_unroll option which may leads to a performance improvement with the cost of compilation time. Default is 1 (no unroll) which is optimal in the tested scene (960x540, many triangles).
  4. Not updating the canvas row by row, but generate a full new canvas and merge after rendering is complete.

Also

  1. Added annotations for easier tracing and profiling
  2. Make all jit with inline=True. This may not lead to any performance improvement (or degradation though).
  3. Bump minimum Python version to 3.10; lower minimum jax/jaxlib version to 0.3.25.

This should be more beneficial, as followed by the discussions in Jax repository, see
jax-ml/jax#6584 jax-ml/jax#6681 jax-ml/jax#9298 jax-ml/jax#9342
…her for functions

add `@ad_tracing_name` to most functions to assist profiling
also bump to Python 3.10

BREAKING CHANGE: Now requires Python 3.10
This is very similar to map + vmap (minibatch processing) as the inner
loop is too complex
under `vmap`, `lax.cond` are lowered to `select_n` in HLO which leads to execution in both branches,
thus fails to 1) save computation when possible; 2) prevent unexpected values to be
produced/unexpected branches to be executed (defensive), thus let the non-dummy branch to be
executed anyway and only rule-out garbage value at the final stage all together to try to improve
performance. See google/brax#8409 for more details about unconditional executation of cond under
vmap
Bump minimum Python version from 3.9 to 3.10;
lower minimum jax & jaxlib to 0.3.25.
@JoeyTeng JoeyTeng added the enhancement New feature or request label Jun 12, 2023
@JoeyTeng JoeyTeng self-assigned this Jun 12, 2023
@JoeyTeng JoeyTeng merged commit 788788a into master Jun 12, 2023
@JoeyTeng JoeyTeng deleted the 0.3.0 branch June 12, 2023 09:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant