Remove swapaxes before and after scan #7116
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #7116      +/-   ##
==========================================
- Coverage   92.21%   91.79%   -0.43%
==========================================
  Files         101      101
  Lines       16901    16900       -1
==========================================
- Hits        15586    15514      -72
- Misses       1315     1386      +71
```
My understanding is that the transpose is there so scan iterates over draws (usually 1k) instead of chains (usually 4); otherwise there's little difference between the scan and vmap options. It may, however, be needed to jit this function so that JAX avoids duplicating memory. This didn't seem relevant in the original vmap branch.
Ah, I knew it was there for a reason, just couldn't figure out why. Now it makes a bit more sense.
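For context, a minimal sketch (with assumed shapes, not tied to the PyMC code) of how the sample layout decides which dimension `scan` loops over:

```python
# Illustration only, with assumed shapes: jax.lax.scan iterates over the leading
# axis of xs, so the layout of the samples decides whether the loop runs over
# chains or over draws.
import jax.numpy as jnp
from jax.lax import scan

x = jnp.zeros((4, 1000, 3))  # (chains, draws, param_dim)

_, out = scan(lambda carry, xi: (carry, xi.sum()), (), x)
print(out.shape)  # (4,) -> 4 scan steps, one per chain

x_t = jnp.swapaxes(x, 0, 1)  # (draws, chains, param_dim)
_, out_t = scan(lambda carry, xi: (carry, xi.sum()), (), x_t)
print(out_t.shape)  # (1000,) -> 1000 scan steps, one per draw
```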
Some more variants to throw into the bake-off:

(2) showed a smaller footprint (1.8 GiB), but still not as small as what's in the PR (scan for chains and vmap for draws, which used 1.6 GiB). I'm also not 100% confident of my testing methodology of using memray and a smaller trace with only 100 draws. But FWIW, this change does allow my large model to finish sampling now with a 30 GB memory limit, when it was previously OOM under a 48 GB limit.

(1) vmap over chains, scan over draws:

```python
def test_vmap_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    final_fn = jax.vmap(
        fun=scan_over_draws,
        in_axes=0,   # chains
        out_axes=0,  # output it back as the leading axis
    )
    ret = final_fn(*_device_put(raw_mcmc_samples, postprocessing_backend))
```

(2) nested scan over chains and draws:

```python
def test_scan_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    def scan_over_chains(*x):
        _, outs = scan(
            f=lambda _, xx: ((), scan_over_draws(*xx)),
            init=(),
            xs=x,
        )
        return outs

    ret = scan_over_chains(*_device_put(raw_mcmc_samples, postprocessing_backend))
```

With the following memory footprints:
Did you try jitting? Does it change anything for even a single use like this?
Regarding the best option: we added this because we were getting OOM with everything vmapped. You may want to check whether your OOM is related to JAX preallocating too much memory; there's a config flag for that.
I'm not exactly sure which function would be jit-compiled. Here I'm trying the whole thing:

```python
def test_jit_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def raw_fn():
        jax_vfn = jax.vmap(jax_fn)
        _, outs = scan(
            lambda _, x: ((), jax_vfn(*x)),
            (),
            _device_put(raw_mcmc_samples, postprocessing_backend),
        )
        return outs

    jit_fn = jax.jit(raw_fn)
    ret = jit_fn()
```
Oh, I haven't read into JAX's preallocation at all. What's the config flag you're referring to?
Seems to only matter for GPU: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
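For reference, a hedged sketch of the preallocation settings described on that page (GPU-only; they have to be set before JAX initializes its backend):

```python
# From the linked JAX GPU memory allocation docs: these flags only affect GPU
# backends and must be set before JAX allocates its backend (i.e. before using jax).
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # don't grab ~75% of GPU memory up front
# or cap the preallocated fraction instead:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax  # import after setting the flags
```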
I'm surprised by this, but I guess JAX's traced-naive transpose+scan just sucks memory-wise (the JAX docs say swapaxes may return copies, which I guess it is doing in your case). The problem is that I don't know if this is general. @fonnesbeck could you test if this also fixes the memory problems you were seeing in your model?
@JasonTam do you want to try nested scan as well, since you already did so many permutations :b That should be the most extreme option at the opposite end from just vmap.
Nested scan is the implementation (2) in the comment above (test_scan_scan), where the scan over chains wraps the scan over draws.
Thanks @JasonTam, I missed it. Due to the fear of overfitting to one example, I would perhaps go for nested scan? WDYT?
I too am afraid of my test setup not generalizing well, so I'm going to try to run some more tests. But also, since this is probably a widely used function for most users, I'd consider putting it under another option. I also think the existing options "vmap" and "scan" are a little misleading, since both use vmap internally. I would definitely also appreciate feedback from @ferrine, as the previous author of these bits.
@JasonTam any news? Would be great to patch this one up :)
@ricardoV94 I haven't had time to play with this unfortunately.
When testing this larger model, some of these methods were failing. But I need to wait my turn on a cluster to make sure there's no interference, so testing these methods has been slow.
@ricardoV94 here are some results from a larger test:
5 variables of shape (4 chains, 1000 draws, ...). Tested on an Azure k8s cluster with Epdsv5-series VMs (3.0 GHz CPU), where each job has plenty of CPU and memory to spare.
I hope I'm understanding which method goes with which dimension correctly. From these results, it does seem like
I still like the nested scan better, because we know that vmap was the source of the problem in the case that first motivated these changes. Unfortunately we don't have a way to retrieve that example, but I suspect the current solution in this PR would be a regression there.
Description
Currently, the results of `scan` are evaluated in `_postprocess_samples`, and then the axes are fixed in the list comprehension `[jnp.swapaxes(t, 0, 1) for _, t in outs]`. This seems to unnecessarily double the peak memory footprint of this method. Admittedly, I don't know much about `scan` and the jaxified function, but it seems that we may not need to transpose before and after. From what I gather, it doesn't matter whether the inputs/outputs have dimensions (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comprehension should lower the peak memory footprint by about half (?).

In my testing, outputs were exactly the same after omitting the double transpose.

(But if the axis swaps are indeed necessary, maybe the operations can still be combined in a way that avoids the list comprehension at the end.)
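For illustration, a minimal sketch of the two layouts being compared; the helper names and shapes are assumptions, not the actual `_postprocess_samples` code:

```python
# Hypothetical sketch (not the actual _postprocess_samples code): `jax_fn` maps a
# single posterior sample to a tuple of transformed arrays, and `samples` is a
# list of arrays shaped (chains, draws, ...).
import jax
import jax.numpy as jnp
from jax.lax import scan

def postprocess_with_swapaxes(jax_fn, samples):
    # Current approach: move draws to the front so scan loops over draws
    # while vmap covers chains...
    swapped = [jnp.swapaxes(s, 0, 1) for s in samples]  # (draws, chains, ...)
    _, outs = scan(lambda _, x: ((), jax.vmap(jax_fn)(*x)), (), swapped)
    # ...then swap every output back so chains is the leading axis again.
    return [jnp.swapaxes(t, 0, 1) for t in outs]

def postprocess_without_swapaxes(jax_fn, samples):
    # This PR's layout: scan loops over chains and vmap covers draws, so the
    # swapaxes before and after (and any copies they trigger) are not needed.
    _, outs = scan(lambda _, x: ((), jax.vmap(jax_fn)(*x)), (), samples)
    return outs
```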
Memory usage tested with the following:
(Note: I couldn't get the legacy chunked xmap method to work -- ran into some jax issue I couldn't decipher)
With the following results using memray:
(Notice: 2.8 GiB for `test_scan_vmap` and 1.6 GiB for `test_scan_vmap_wo_transpose`.)

Aside: I originally sought to bring back some notion of an `n_chunks` param to trade off runtime vs. peak memory, but I guess that didn't really work out. Even at half the memory footprint, `_postprocess_samples` seems very peaky.

Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7116.org.readthedocs.build/en/7116/