
Dimension arrangements #142

Closed
roflmaostc opened this issue Jun 17, 2024 · 11 comments


roflmaostc commented Jun 17, 2024

Hi,

for our Julia package WaveOpticsPropagation.jl, I'm trying to rework our dimension layout a bit.

It seems like the (x, y) dimensions are not the innermost ones:
https://chromatix.readthedocs.io/en/latest/101/#creating-a-field

Doesn't that create a bottleneck, since the FFTs are usually computed along x and y?

For example, in Julia this makes a factor-of-4 difference:

julia> using CUDA, CUDA.CUFFT, BenchmarkTools

julia> xc = CUDA.rand(128, 128, 128);

julia> p = plan_fft(xc, (2,3));

julia> @benchmark CUDA.@sync $p * $xc
BenchmarkTools.Trial: 2681 samples with 1 evaluation.
 Range (min … max):  1.792 ms …  25.102 ms  ┊ GC (min … max): 0.00% … 96.38%
 Time  (median):     1.849 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.859 ms ± 449.835 μs  ┊ GC (mean ± σ):  0.69% ±  3.01%

                        ▁▄▂▂▃▆▆▅▆█▆▇▅▇▄▅█▃▃▂▃▁                 
  ▂▁▁▂▁▃▂▁▂▂▃▃▃▃▄▄▄▅▅▆██████████████████████████▆▅▅▅▄▄▃▃▃▃▃▃▂ ▅
  1.79 ms         Histogram: frequency by time        1.89 ms <

 Memory estimate: 3.86 KiB, allocs estimate: 157.

julia> p = plan_fft(xc, (1,2));

julia> @benchmark CUDA.@sync $p * $xc
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  414.595 μs …  25.045 ms  ┊ GC (min … max): 0.00% … 97.07%
 Time  (median):     419.495 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   425.350 μs ± 248.092 μs  ┊ GC (mean ± σ):  1.00% ±  2.90%

       ▂▆▇█▇▅▃▂▁▁▁                                               
  ▂▂▃▄▇████████████▇▅▄▃▃▃▃▂▂▂▂▂▂▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂▁▂ ▃
  415 μs           Histogram: frequency by time          444 μs <

 Memory estimate: 4.09 KiB, allocs estimate: 161.

I was wondering how I should arrange the layout in Julia.
Because of Julia's column-major memory layout, I would go with (x, y, z, polarization, wavelength, batch).

CC: @RainerHeintzmann

roflmaostc (Author) commented Jun 17, 2024

Also in Python this seems to make a difference:

In [12]: xc = torch.rand(128, 128, 128).cuda()

In [13]: xc.shape
Out[13]: torch.Size([128, 128, 128])

In [14]: xc.dtype
Out[14]: torch.float32

In [15]: %%timeit
    ...: torch.fft.fftn(xc, dim=(0,1))
    ...: torch.cuda.synchronize()
    ...: 
    ...: 
471 µs ± 201 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [16]: %%timeit
    ...: torch.fft.fftn(xc, dim=(1,2))
    ...: torch.cuda.synchronize()
    ...: 
    ...: 
295 µs ± 277 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Not sure why PyTorch is faster than Julia here, though (I opened an issue over at CUDA.jl).

diptodip (Collaborator) commented Jun 17, 2024

Thanks Felix, this is a great catch! The FFT is also sometimes slower in JAX with the axis order we use:

chromatix arrangement fft2 time for (128, 128, 128): 0.00011737400200217962 +/- 6.774095103994204e-05
inner arrangement fft2 time for (128, 128, 128): 9.083172772079706e-05 +/- 1.7062060562748385e-06
chromatix arrangement fft2 time for (128, 256, 256): 0.00021060981089249253 +/- 2.100722751556226e-05
inner arrangement fft2 time for (128, 256, 256): 0.00021538520231842995 +/- 2.5950248370836355e-06
chromatix arrangement fft2 time for (128, 512, 512): 0.0005954901920631528 +/- 3.403699545818871e-05
inner arrangement fft2 time for (128, 512, 512): 0.0006192663800902665 +/- 2.50930329870112e-06
chromatix arrangement fft2 time for (128, 1024, 1024): 0.002230075397528708 +/- 0.00036002158224289005
inner arrangement fft2 time for (128, 1024, 1024): 0.002260779996868223 +/- 8.880622902105098e-06
chromatix arrangement fft2 time for (128, 2048, 2048): 0.009496355836745352 +/- 0.0018868240660353255
inner arrangement fft2 time for (128, 2048, 2048): 0.009488527581561356 +/- 1.5524316672778253e-05

All times are in seconds averaged over 10 runs on an H100.
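
For completeness, a minimal sketch of how such a timing comparison could be set up in JAX (this is not the exact script used for the numbers above; the helper name, sizes, and axis choices are illustrative):

import time
import jax
import jax.numpy as jnp

def time_fft2(x, axes, n_runs=10):
    # jit-compile the 2D FFT over the chosen axes, then time steady-state runs
    f = jax.jit(lambda a: jnp.fft.fftn(a, axes=axes))
    f(x).block_until_ready()  # warm-up run triggers compilation
    times = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        f(x).block_until_ready()
        times.append(time.perf_counter() - t0)
    mean = sum(times) / n_runs
    std = (sum((t - mean) ** 2 for t in times) / n_runs) ** 0.5
    return mean, std

key = jax.random.PRNGKey(0)
for n in (128, 256, 512, 1024, 2048):
    x = jax.random.normal(key, (128, n, n))
    print(n, "spatial axes leading:", time_fft2(x, axes=(0, 1)))
    print(n, "spatial axes trailing:", time_fft2(x, axes=(1, 2)))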

This layout is a choice we made at the beginning of this project (when we didn't have the vectorial dimension yet) so that it would match the default neural network image layout in JAX/TF of (batch, height, width, channels).

I think the difference becomes much less significant for larger sizes and also flip-flops between which order is faster. We would also likely have to look at what happens when we combine the simulations with neural networks to see how that affects overall timings. So I'm not sure we actually want to change anything.

I'm also not sure about putting the vectorial dimensions on the outside, because that would change how the matrix multiplications operate as well. But we should definitely investigate (tagging @GJBoth).

GJBoth (Collaborator) commented Jun 17, 2024

Let me try to explain your observations, and why it's probably not an issue for us. I don't have access to a computer right now and am pretty jet-lagged, but I think I'm correct. Would love to see this code for JAX though!

All of this comes down to memory layout: NumPy (and therefore JAX) is row-major by default, so the last axes are contiguous in memory. That's why in your PyTorch example the FFT over axes (1, 2) is faster than over (0, 1): transforming the trailing axes reads contiguous memory. This is also why Python libraries tend to put the spatial axes toward the end of the shape rather than the front.

So with multiple wavelengths and polarizations the FFTs might indeed be slower, but conversely, operations broadcasting over the polarization and wavelength axes would then take the hit; for example, all our polarization code would slow down. The ideal memory layout depends on your specific use case, and we chose this axis layout to stay in line with other Python approaches.
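
To make the contiguity argument concrete, one can inspect the strides of a row-major array (NumPy shown here for the stride inspection; JAX arrays use the same default layout, and the (batch, height, width, channels) shape is just for illustration):

import numpy as np

x = np.zeros((4, 128, 128, 3))  # (batch, height, width, channels), C order
# strides are bytes per step along each axis; the last axis has the
# smallest stride, i.e. neighbouring channel values are adjacent in memory
print(x.strides)  # (393216, 3072, 24, 8) for float64

An FFT over the height/width axes of this layout therefore steps through memory in 3072- and 24-byte strides, whereas an FFT over the last two axes of a spatial-last layout would read long contiguous runs.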

As for why it's probably not an issue for us and the difference should be minor (but do double-check): we have the batch axis in front, and that axis is usually much, much bigger than the number of wavelengths.

I'm also wondering whether the JAX compiler doesn't simply compile this away.
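
One way to check that would be to inspect the compiled program for layout-changing ops around the FFT, e.g. via jax.jit's lowering API (a sketch; what XLA actually emits depends on the backend and version):

import jax
import jax.numpy as jnp

@jax.jit
def fft_spatial(field):
    # FFT over the height/width axes of a (batch, H, W, C) field
    return jnp.fft.fftn(field, axes=(1, 2))

x = jnp.zeros((4, 128, 128, 3), dtype=jnp.complex64)
hlo = fft_spatial.lower(x).compile().as_text()
print("transpose" in hlo.lower())  # True would hint at a layout change around the FFT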

roflmaostc (Author) commented:

I'm surprised that JAX does not follow the standard convention.

Yes, the observations are consistent (just reversed in Julia, since it is column-major).

Aren't the polarization operations usually just elementwise? Isn't the FFT the bottleneck at the end of the day?

I guess a profiler would help to test some examples...

diptodip (Collaborator) commented Jun 17, 2024

The vectorial operations are matrix multiplies at each location, which means we would want those dimensions to be contiguous.

Also, there is not really a standard convention (assuming you are talking about the (batch, height, width, channels) layout); as far as I know, each framework's choice came down to which layout was faster with the CUDA kernels available at the time. So PyTorch chose a different layout than TF, and both stick with their own conventions.
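
To make the vectorial point concrete: with a trailing polarization axis, the per-pixel matrix multiply contracts over the contiguous last axis. A toy sketch (the shapes and the Jones-style matrix are illustrative, not Chromatix's actual API):

import jax.numpy as jnp

B, H, W, C, V = 2, 128, 128, 1, 3  # batch, height, width, wavelength, vectorial
field = jnp.ones((B, H, W, C, V), dtype=jnp.complex64)
J = jnp.eye(V, dtype=jnp.complex64)  # a 3x3 Jones-like matrix (identity as a placeholder)

# apply the same matrix at every pixel, contracting over the trailing vectorial axis
out = jnp.einsum("ij,bhwcj->bhwci", J, field)
print(out.shape)  # (2, 128, 128, 1, 3)

If the vectorial axis sat at the front of the shape instead, the same contraction would gather elements that are far apart in memory.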

diptodip (Collaborator) commented:

And as you can see, our current layout can be faster even for the FFT, depending on the size.

GJBoth (Collaborator) commented Jun 17, 2024

To echo @diptodip, I'm nearly 100% sure both JAX and PyTorch use [H, W, C], since the CNNs operate over the channel axis.

And for the polarization: in free space, yes, the operations are elementwise, but in samples they often become matrix multiplications because you get mixing between the components.

diptodip (Collaborator) commented:

PyTorch actually uses channels-first as the default, but in both frameworks you have the option to choose.

roflmaostc (Author) commented:

But in general you would recommend a 5D or 6D scheme?

For my Julia package I'm tempted to use (x, y, z, polarization, wavelength, batch).

diptodip (Collaborator) commented Jun 17, 2024

We've found keeping all of those dimensions useful (except z), and that order seems fine to me for Julia (though I don't know how it interacts with other Julia libraries or neural networks, or whether that's a concern for you).

GJBoth (Collaborator) commented Jun 19, 2024

Closing for now, feel free to reopen if you have more questions.

GJBoth closed this as completed Jun 19, 2024