
Debugging performance discrepancy between PyTorch and JAX variants of NVDiffrast #21

Open — wants to merge 2 commits into base: main

Conversation

horizon-blue (Member)

The following notes are adapted from the related Notion card.

Benchmarking script: b3d/test/test_renderer_fps.py

Before

Output of python test/test_renderer_fps.py:

Torch
Resolution: 1024x1024, FPS: 2289.8922289071115
Resolution: 512x512, FPS: 2861.1917766822266
Resolution: 256x256, FPS: 2985.5311176707996
Resolution: 128x128, FPS: 2893.421633554084
Resolution: 64x64, FPS: 3079.929917213607
Resolution: 32x32, FPS: 3096.469365336231
JAX NVdiffrast Original
Resolution: 1024x1024, FPS: 911.1996657873235
Resolution: 512x512, FPS: 977.560392142513
Resolution: 256x256, FPS: 979.7225675373279
Resolution: 128x128, FPS: 1028.6825650452315
Resolution: 64x64, FPS: 1068.0791248191986
Resolution: 32x32, FPS: 1055.3279884702326
JAX
Resolution: 1024x1024, FPS: 801.9614317778946
Resolution: 512x512, FPS: 979.8047307075504
Resolution: 256x256, FPS: 1003.173856036569
Resolution: 128x128, FPS: 1028.2283884215265
Resolution: 64x64, FPS: 1028.7337827974036
Resolution: 32x32, FPS: 996.5328208316811
JAX through torch DLPACK
Resolution: 1024x1024, FPS: 2133.707747992481
Resolution: 512x512, FPS: 2451.4316807922564
Resolution: 256x256, FPS: 2519.926342256174
Resolution: 128x128, FPS: 2552.6166684620107
Resolution: 64x64, FPS: 2549.680826683509
Resolution: 32x32, FPS: 2548.86892264312
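
For context, the "JAX through torch DLPACK" rows presumably run the PyTorch rasterizer while exchanging buffers with JAX via DLPack, which avoids copies between the two frameworks. A minimal sketch of that kind of hand-off (the doubling op is just a placeholder for the torch rasterizer call; this is not the b3d code):

```python
import jax.dlpack
import jax.numpy as jnp
import torch.utils.dlpack

def torch_op_from_jax(x_jax):
    # JAX array -> torch tensor, zero-copy, via a DLPack capsule.
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    # Any torch GPU code can run here; a trivial op stands in for the torch rasterizer.
    y_torch = x_torch * 2.0
    # torch tensor -> JAX array, again without copying.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(y_torch))

print(torch_op_from_jax(jnp.ones((4, 4))))
```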

First change: Using lax.scan instead of a for loop

This should let us get rid of some of the overhead from dispatching each iteration to XLA separately…

(Note: lax.while_loop should achieve a similar effect.)
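
As a rough sketch of what this change looks like (render below is a trivial stand-in for the actual rasterizer call in test_renderer_fps.py, not the real code):

```python
import jax
import jax.numpy as jnp

def render(pose):
    # Trivial stand-in for a single call into the JAX rasterizer.
    return jnp.sin(pose).sum()

# Before: a plain Python for loop; each iteration is traced/dispatched separately.
def render_loop(pose, n=1000):
    out = jnp.zeros(())
    for _ in range(n):
        out = out + render(pose)
    return out

# After: lax.scan rolls the n iterations into a single compiled loop.
def render_scan(pose, n=1000):
    def step(carry, _):
        return carry + render(pose), None
    out, _ = jax.lax.scan(step, jnp.zeros(()), None, length=n)
    return out

pose = jnp.zeros((4, 4))
print(jax.jit(render_scan)(pose))
```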

Related:

Torch
Resolution: 1024x1024, FPS: 2279.370582816835
Resolution: 512x512, FPS: 2852.7848003977556
Resolution: 256x256, FPS: 2974.3082078999955
Resolution: 128x128, FPS: 3037.0707208211807
Resolution: 64x64, FPS: 3070.222029625422
Resolution: 32x32, FPS: 3086.450405131643
JAX NVdiffrast Original (with lax.scan)
Resolution: 1024x1024, FPS: 1461.7752904693164
Resolution: 512x512, FPS: 1567.3793478171508
Resolution: 256x256, FPS: 1674.5947965957964
Resolution: 128x128, FPS: 1648.1440076389122
Resolution: 64x64, FPS: 1716.1011730751213
Resolution: 32x32, FPS: 1643.5740481460973
JAX (with lax.scan)
Resolution: 1024x1024, FPS: 733.1984047710031
Resolution: 512x512, FPS: 843.3599566969871
Resolution: 256x256, FPS: 875.2602479306937
Resolution: 128x128, FPS: 890.9948700344285
Resolution: 64x64, FPS: 881.9893748380823
Resolution: 32x32, FPS: 881.0593119580953
JAX through torch DLPACK
Resolution: 1024x1024, FPS: 2138.8324983108323
Resolution: 512x512, FPS: 2495.4761617194094
Resolution: 256x256, FPS: 2419.0390063395685
Resolution: 128x128, FPS: 2492.704287093078
Resolution: 64x64, FPS: 2331.9294956581107
Resolution: 32x32, FPS: 2015.288947657216

Second change: Removing unnecessary cudaStreamSynchronize(stream)

Disclaimer: I’m not certain about this change, since I’m new to CUDA programming.

It looks like we’re calling cudaStreamSynchronize(stream) a lot in the JAX rasterize wrapper code (e.g. jax_rasterize_gl.cpp). However, except for debugging, we probably don’t want to block the CPU until the stream has finished execution?
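
A related point on benchmarking (a general JAX property, nothing specific to this PR's code): JAX dispatches work to the GPU asynchronously, so once the wrapper stops synchronizing on every call, a fair timing loop should block on the final result before reading the clock, e.g.:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def render_once(pose):
    # Toy stand-in for one rasterizer call.
    return (pose @ pose).sum()

pose = jnp.eye(512)
render_once(pose).block_until_ready()  # warm-up so compilation stays out of the timed region

start = time.perf_counter()
out = render_once(pose)   # returns immediately; the work is queued on the device stream
out.block_until_ready()   # block once at the end, instead of after every kernel
print("elapsed:", time.perf_counter() - start)
```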

By deleting cudaStreamSynchronize(stream) from the C++ implementations, we can see another performance bump on the JAX rasterizer:

Torch
Resolution: 1024x1024, FPS: 2252.760290699954
Resolution: 512x512, FPS: 2845.9962164700264
Resolution: 256x256, FPS: 2973.428943105569
Resolution: 128x128, FPS: 3034.5482222337982
Resolution: 64x64, FPS: 3070.316423172255
Resolution: 32x32, FPS: 3083.6706696094448
JAX NVdiffrast Original (+lax.scan, -stream synchronize)
Resolution: 1024x1024, FPS: 1990.3167714979204
Resolution: 512x512, FPS: 2224.1793298885977
Resolution: 256x256, FPS: 2351.079125399175
Resolution: 128x128, FPS: 2398.370552843241
Resolution: 64x64, FPS: 2321.356760696755
Resolution: 32x32, FPS: 2422.3360854559483
JAX (with lax.scan)
Resolution: 1024x1024, FPS: 719.8324700203096
Resolution: 512x512, FPS: 813.4414067909095
Resolution: 256x256, FPS: 866.9300890368162
Resolution: 128x128, FPS: 886.0318977938898
Resolution: 64x64, FPS: 874.738734019031
Resolution: 32x32, FPS: 879.9639144436635
JAX through torch DLPACK
Resolution: 1024x1024, FPS: 2150.022118915184
Resolution: 512x512, FPS: 2510.152393628481
Resolution: 256x256, FPS: 2445.71392585604
Resolution: 128x128, FPS: 2503.38057221437
Resolution: 64x64, FPS: 2480.861957195516
Resolution: 32x32, FPS: 2505.917216327311
  • Note 1: It seems like removing all stream synchronization from the b3d version of the renderer can result in nondeterministic CUDA errors. I haven’t taken a close look at the b3d version of the renderer (aka "JAX" in the tables) to figure out which calls are safe to remove, so those numbers are not included above. Even when it doesn't error out, though, we don't see the same performance boost:

    JAX (+lax.scan, -stream synchronize)
    Resolution: 1024x1024, FPS: 870.6132853762722
    Resolution: 512x512, FPS: 1024.7730832097457
    Resolution: 256x256, FPS: 1076.2360742862113
    Resolution: 128x128, FPS: 1100.484477183132
    Resolution: 64x64, FPS: 1077.2718283270622
    Resolution: 32x32, FPS: 1078.6340247080325
  • Note 2: After removing the unnecessary cudaStreamSynchronize call, the output of JAX NVDiffrast is still the same as the PyTorch version:

    Torch
    tensor(1.0373e+10, device='cuda:0')
    Resolution: 1024x1024, FPS: 2289.497242321952
    tensor(2.5595e+09, device='cuda:0')
    Resolution: 512x512, FPS: 2802.909093825159
    tensor(6.5713e+08, device='cuda:0')
    Resolution: 256x256, FPS: 2985.9902140089316
    tensor(1.6092e+08, device='cuda:0')
    Resolution: 128x128, FPS: 2934.1425775210882
    tensor(36074596., device='cuda:0')
    Resolution: 64x64, FPS: 3080.278247790938
    tensor(15485054., device='cuda:0')
    Resolution: 32x32, FPS: 3090.753539481625
    JAX NVdiffrast Original (+lax.scan, -stream synchronize)
    10227558000.0
    Resolution: 1024x1024, FPS: 1974.4693350983327
    2544131600.0
    Resolution: 512x512, FPS: 2218.886785725244
    655813100.0
    Resolution: 256x256, FPS: 2291.8904977588204
    160595570.0
    Resolution: 128x128, FPS: 2384.597322781374
    36074596.0
    Resolution: 64x64, FPS: 2390.8363615635644
    15252779.0
    Resolution: 32x32, FPS: 2363.0174571177786

:) I'm just pushing the code here so people can give it a try. Even though we're only tweaking the rasterization operator here, this can give us some ideas about how to improve the performance of the overall rendering pipeline. @nishadgothoskar

@@ -109,7 +106,7 @@ void jax_rasterize_fwd_original_gl(cudaStream_t stream,
     const int32_t* rangesPtr = 0;
     const int32_t* triPtr = tri;
     int vtxPerInstance = d.num_vertices;
-    rasterizeRender(NVDR_CTX_PARAMS, s, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, rangesPtr, width, height, depth, peeling_idx);
+    rasterizeRender(NVDR_CTX_PARAMS, s, stream, posPtr, posCount, vtxPerInstance, triPtr, triCount, ranges, width, height, depth, peeling_idx);
Collaborator

I think this change will actually modify the behavior, right?

Member Author

For our benchmark, passing in 0 or ranges doesn't seem to make any difference to the output or the compute time (since we're just rendering a single scene).

I thought it was a mistake not to pass in the actual range (since we are using it in the current b3d renderer), but I can revert this change if this is intentional :)

horizon-blue (Member Author) commented May 24, 2024

Okay, I just reverted the change to rangesPtr.

Another thing that's worth mentioning: even though we were rendering 1,000 times in our benchmark, the entire loop still executes in less than 0.5 s, so the fixed overhead from XLA can still dominate. To see when JAX (and XLA) starts to shine, we can make the compute graph bigger, e.g. by rendering 50,000 times instead:

Torch
Resolution: 1024x1024, FPS: 2134.2519796143943
Resolution: 512x512, FPS: 2596.5515342340595
Resolution: 256x256, FPS: 2717.106289678845
Resolution: 128x128, FPS: 2737.943357193579
Resolution: 64x64, FPS: 2774.9455110729173
Resolution: 32x32, FPS: 2783.2728427894226
JAX NVdiffrast Original (+lax.scan, -stream synchronize)
Resolution: 1024x1024, FPS: 2394.8623978344613
Resolution: 512x512, FPS: 2676.379275207601
Resolution: 256x256, FPS: 2749.8181835823625
Resolution: 128x128, FPS: 2814.844443132049
Resolution: 64x64, FPS: 2838.0698350321295
Resolution: 32x32, FPS: 2805.812199341117

JAX's performance numbers start to catch up here, because jax.lax.scan compiles the entire loop into a single XLA WhileOp, which lets us bypass the overhead from executing the individual operators.
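
For anyone who wants to verify this, the lowered program can be inspected directly; in reasonably recent JAX versions the scanned loop appears as a single while construct rather than thousands of unrolled ops (the step below is a toy stand-in for the real render call):

```python
import jax
import jax.numpy as jnp

def render_many(pose):
    def step(carry, _):
        return carry + jnp.sin(pose).sum(), None  # toy per-iteration "render"
    out, _ = jax.lax.scan(step, jnp.zeros(()), None, length=50_000)
    return out

lowered = jax.jit(render_many).lower(jnp.zeros((4, 4))).as_text()
print("while" in lowered)  # True: the 50,000 iterations lower to one while loop
```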

(Though it seems like PyTorch is also introducing its own while_loop primitive, so it might be possible to make the PyTorch version faster with it.)
