Remove swapaxes before and after scan #7116

Open
wants to merge 1 commit into
base: main

JasonTam
Contributor

@JasonTam JasonTam commented Jan 24, 2024

Description

Currently, the results of scan are evaluated in _postprocess_samples, and then the axes are fixed in the list comprehension [jnp.swapaxes(t, 0, 1) for _, t in outs]. This seems to unnecessarily double the peak memory footprint of this method. Admittedly, I don't know much about scan and the jaxified function, but it seems that we may not need to transpose before and after.
From what I gather, it doesn't matter whether the inputs/outputs have dimensions (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comprehension should lower the peak memory footprint by about half (?)
In my testing, outputs were exactly the same after omitting the double transpose.

(But if the axis swaps are indeed necessary, maybe the operations can still be combined in a way that avoids the list comp at the end.)
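
As a small illustration of the equivalence claim above, the following sketch compares both variants on a toy function. `toy_fn` and the random `samples` are hypothetical stand-ins for the jaxified graph and the raw MCMC samples; they are not taken from the PR, and the sketch says nothing about memory use.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def toy_fn(a, b):
    # Hypothetical stand-in for the jaxified graph: maps one draw to its outputs.
    return [jnp.exp(a), a + b.sum()]


chains, draws = 4, 10
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Two "variables", each shaped (chains, draws, ...).
samples = [
    jax.random.normal(key_a, (chains, draws)),
    jax.random.normal(key_b, (chains, draws, 3)),
]

vfn = jax.vmap(toy_fn)

# Current approach: swap to (draws, chains, ...), scan over draws, swap back.
t_samples = [jnp.swapaxes(t, 0, 1) for t in samples]
_, outs = scan(lambda _, x: ((), vfn(*x)), (), t_samples)
with_transpose = [jnp.swapaxes(t, 0, 1) for t in outs]

# Proposed approach: scan over chains directly, vmap over draws.
_, without_transpose = scan(lambda _, x: ((), vfn(*x)), (), samples)
```

Both lists hold arrays shaped (chains, draws, ...); only the axis that scan walks over differs.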

Memory usage tested with the following:

import pickle
from pathlib import Path

import jax
import jax.numpy as jnp
import pytest
from jax.experimental.maps import SerialLoop, xmap
from jax.lax import scan
from pymc.sampling.jax import _device_put, get_jaxified_graph

CUR_DIR = Path(__file__).parent

DIR_FIXTURES = CUR_DIR / "../fixtures/profiling"

PATH_MODEL = DIR_FIXTURES / "model_pm.p"
PATH_DATA = DIR_FIXTURES / "raw_mcmc_samples.p"

postprocessing_backend = None


@pytest.fixture
def model():
    # use a context manager so the file handle is closed after loading
    with open(PATH_MODEL, "rb") as f:
        return pickle.load(f)


@pytest.fixture
def raw_mcmc_samples():
    with open(PATH_DATA, "rb") as f:
        return pickle.load(f)


def get_jax_fn(model):
    vars_to_sample = [
        v for v in model.unobserved_value_vars if not v.name.endswith("__")
    ]
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    return jax_fn


def test_scan_vmap(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
    jax_vfn = jax.vmap(jax_fn)
    _, outs = scan(
        lambda _, x: ((), jax_vfn(*x)),
        (),
        _device_put(t_raw_mcmc_samples, postprocessing_backend),
    )
    ret = [jnp.swapaxes(t, 0, 1) for t in outs]


def test_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    jax_vfn = jax.vmap(jax_fn)
    _, outs = scan(
        lambda _, x: ((), jax_vfn(*x)),
        (),
        _device_put(raw_mcmc_samples, postprocessing_backend),
    )
    ret = outs


def test_nested_vmap(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    ret = jax.vmap(jax.vmap(jax_fn))(
        *_device_put(raw_mcmc_samples, postprocessing_backend)
    )


def test_looped_vmap(model, raw_mcmc_samples, num_chunks=100):
    # https://discourse.pymc.io/t/nameerror-unbound-axis-name-raised-during-transformation-of-variables-after-sample-numpyro-nuts/11167/5
    jax_fn = get_jax_fn(model)

    # dims are vars, chains, draws, ...
    raw_mcmc_samples = _device_put(raw_mcmc_samples, postprocessing_backend)
    f = jax.vmap(jax.vmap(jax_fn))
    draws = len(raw_mcmc_samples[0][0])
    # guard against a zero step when draws < num_chunks
    step = max(1, draws // num_chunks)
    segs = list(range(0, draws, step)) + [draws]
    # dims are chunks, vars, chains, draws, ...
    outputs = [
        f(*[var_samples[:, i:j] for var_samples in raw_mcmc_samples])
        for i, j in zip(segs[:-1], segs[1:])
    ]
    # dims of var_chunks are chunks, chains, draws, ...
    ret = [jnp.concatenate(var_chunks, axis=1) for var_chunks in zip(*outputs)]

(Note: I couldn't get the legacy chunked xmap method to work -- ran into some jax issue I couldn't decipher)

With the following results using memray:

========================================================================= MEMRAY REPORT =========================================================================
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_looped_vmap at the high watermark

         📦 Total memory allocated: 3.4GiB
         📏 Total allocations: 5272432
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - <lambda>:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 1.3GiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.7MiB
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.7MiB
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 90.3MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_nested_vmap at the high watermark

         📦 Total memory allocated: 2.9GiB
         📏 Total allocations: 1657858
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - _pjit_call_impl:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/pjit.py:1214 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 163.3MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_vmap at the high watermark

         📦 Total memory allocated: 2.8GiB
         📏 Total allocations: 1416835
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.4GiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
                - <listcomp>:/Users/jason/Wonder/  -ds/src/tests/profiling/test_jax_postproc_profile.py:50 -> 653.4MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 18.6MiB
                - <lambda>:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 4.1MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_vmap_wo_transpose at the high watermark

         📦 Total memory allocated: 1.6GiB
         📏 Total allocations: 605385
         📊 Histogram of allocation sizes: |▁ █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.6GiB
                - <lambda>:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
                - __call__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
                - backend_compile:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 1.6MiB
                - __init__:/Users/jason/mambaforge/envs/  /lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:2277 -> 1.0MiB

(Notice: 2.8GiB for test_scan_vmap and 1.6GiB for test_scan_vmap_wo_transpose)

Aside: I originally sought to bring back some notion of an n_chunks parameter to trade off runtime against peak memory, but I guess that didn't really work out. Even at half the memory footprint, _postprocess_samples seems very peaky.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7116.org.readthedocs.build/en/7116/


codecov bot commented Jan 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (2da4050) 92.21% compared to head (3ab2863) 91.79%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7116      +/-   ##
==========================================
- Coverage   92.21%   91.79%   -0.43%     
==========================================
  Files         101      101              
  Lines       16901    16900       -1     
==========================================
- Hits        15586    15514      -72     
- Misses       1315     1386      +71     
Files Coverage Δ
pymc/sampling/jax.py 93.05% <100.00%> (-0.03%) ⬇️

... and 3 files with indirect coverage changes

@ricardoV94 ricardoV94 requested a review from ferrine January 24, 2024 12:30
@ricardoV94
Member

ricardoV94 commented Jan 24, 2024

From what I gather, it doesn't matter whether the inputs/outputs have dimensions (chains, draws, ...) or (draws, chains, ...). Avoiding the final transpose in the list comprehension should lower the peak memory footprint by about half (?)

My understanding is that the transpose is there so scan iterates over draws (usually 1k) instead of chains (usually 4), otherwise there's little difference between the scan and vmap option.
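
To make the point about iteration counts concrete, here is a minimal sketch that counts the sequential steps scan takes over each layout. The helper `count_scan_steps` and the zero-filled `samples` array are illustrative assumptions, not code from the PR.

```python
import jax.numpy as jnp
from jax.lax import scan


def count_scan_steps(xs):
    # Carry a counter so the number of sequential scan iterations is visible.
    steps, _ = scan(lambda c, x: (c + 1, x.sum()), jnp.int32(0), xs)
    return int(steps)


samples = jnp.zeros((4, 1000))  # (chains, draws)

# Without the transpose, scan walks the chains axis; each step handles all draws.
steps_over_chains = count_scan_steps(samples)
# With the transpose, scan walks the draws axis; each step handles all chains.
steps_over_draws = count_scan_steps(jnp.swapaxes(samples, 0, 1))
```

With 4 chains and 1000 draws, the untransposed layout gives scan only 4 sequential steps, each vmapped over 1000 draws, which is close to the fully vmapped case.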

However, it may be necessary to jit this function so that JAX avoids duplicating memory. This didn't seem relevant in the original vmap branch.

@JasonTam
Contributor Author

Ah, I knew it was there for a reason, just couldn't figure out why. Now it makes a bit more sense.
Then maybe jax-ml/jax#2509 would be helpful, but that issue's been stale for quite some time.

@JasonTam
Contributor Author

JasonTam commented Jan 24, 2024

Some more variants to throw into the bake-off.
I was curious to see:

  1. What if we switched what scan / vmap were being used for, i.e., scan over draws and vmap over chains
  2. What if we used scan for both chains and draws

(2) showed a smaller footprint (1.8GiB), but still not as small as what's in the PR (scan for chains and vmap for draws -- which used 1.6GiB)

I'm also not 100% confident in my testing methodology of using memray with a smaller trace of only 100 draws. But FWIW, this change does allow my large model to finish sampling now with a 30 GB memory limit when it was previously OOM under a 48 GB limit.

def test_vmap_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    final_fn = jax.vmap(
        fun=scan_over_draws,
        in_axes=0,  # chains
        out_axes=0,  # output it back as the leading axis
    )

    ret = final_fn(*_device_put(raw_mcmc_samples, postprocessing_backend))


def test_scan_scan(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)
    
    def scan_over_draws(*x):
        _, outs = scan(
            f=lambda _, xx: ((), jax_fn(*xx)),
            init=(),
            xs=x,
        )
        return outs

    def scan_over_chains(*x):
        _, outs = scan(
            f=lambda _, xx: ((), scan_over_draws(*xx)),
            init=(),
            xs=x,
        )
        return outs

    ret = scan_over_chains(*_device_put(raw_mcmc_samples, postprocessing_backend))

With the following memory footprints:

================================================== MEMRAY REPORT ==================================================
Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_vmap_scan at the high watermark

         📦 Total memory allocated: 2.8GiB
         📏 Total allocations: 1451992
         📊 Histogram of allocation sizes: |  █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.4GiB
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 761.9MiB
                - <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 653.4MiB
                - refresh:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/cmodule.py:851 -> 6.0MiB
                - <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 4.1MiB


Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_scan_scan at the high watermark

         📦 Total memory allocated: 1.8GiB
         📏 Total allocations: 960009
         📊 Histogram of allocation sizes: |▁ █      |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.7GiB
                - <lambda>:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
                - __call__:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
                - backend_compile:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 2.4MiB
                - no_nan:/Users/jason/mambaforge/envs/ /lib/python3.10/site-packages/pytensor/tensor/variable.py:1023 -> 1.6MiB

@ricardoV94
Member

Did you try jitting? Does it change anything for even a single use like this?

@ricardoV94
Member

Regarding the best option. We added this because we were getting OOM with everything vmapped.

You may want to check if your OOM is related to jax pre allocating too much? There's a config flag for that

@JasonTam
Contributor Author

JasonTam commented Jan 24, 2024

Did you try jitting? Does it change anything for even a single use like this?

I'm not exactly sure which function would be jit compiled. Here I'm trying the whole thing:
(1.7GiB, slightly more than the non-jit fn at 1.6GiB)

def test_jit_scan_vmap_wo_transpose(model, raw_mcmc_samples):
    jax_fn = get_jax_fn(model)

    def raw_fn():
        jax_vfn = jax.vmap(jax_fn)
        _, outs = scan(
            lambda _, x: ((), jax_vfn(*x)),
            (),
            _device_put(raw_mcmc_samples, postprocessing_backend),
        )
        return outs
    jit_fn = jax.jit(raw_fn)
    ret = jit_fn()

Allocation results for src/tests/profiling/test_jax_postproc_profile.py::test_jit_scan_vmap_wo_transpose at the high watermark

         📦 Total memory allocated: 1.7GiB
         📏 Total allocations: 523197
         📊 Histogram of allocation sizes: |▁ █ ▁    |
         🥇 Biggest allocating functions:
                - __call__:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149 -> 1.5GiB
                - backend_compile:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/compiler.py:251 -> 88.0MiB
                - _numpy_array_constant:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:252 -> 25.1MiB
                - <lambda>:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/jax/_src/dispatch.py:164 -> 6.6MiB
                - __call__:/Users/jason/mambaforge/envs/whisk/lib/python3.10/site-packages/pytensor/link/c/basic.py:1767 -> 3.3MiB
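
One variation that might be worth trying: in the jitted test above, the samples are captured by closure, so they are seen as constants at trace time. A minimal sketch of passing them as arguments instead follows. `toy_fn`, `postprocess`, and the all-ones `samples` are hypothetical placeholders, and whether this changes peak memory was not measured in this thread.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def toy_fn(a, b):
    # Hypothetical stand-in for the jaxified graph.
    return [jnp.exp(a), a + b.sum()]


vfn = jax.vmap(toy_fn)


@jax.jit
def postprocess(samples):
    # `samples` is a traced argument here rather than a closed-over constant.
    _, outs = scan(lambda _, x: ((), vfn(*x)), (), samples)
    return outs


samples = [jnp.ones((4, 10)), jnp.ones((4, 10, 3))]  # (chains, draws, ...)
jit_outs = postprocess(samples)
```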

Regarding the best option. We added this because we were getting OOM with everything vmapped.

You may want to check if your OOM is related to jax pre allocating too much? There's a config flag for that

Oh, I haven't read into jax's preallocation at all. What's the config flag you're referring to?

@ricardoV94
Member

Oh, I haven't read into jax's preallocation at all. What's the config flag you're referring to?

Seems to only matter for GPU https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
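
For reference, the linked page describes environment variables that control GPU preallocation. A minimal sketch of setting one from Python is below; it assumes the variable is set before JAX initializes its backend, and per the comment above it should only matter on GPU.

```python
import os

# Set before importing jax so the backend picks it up at initialization.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative described on the same page: cap the preallocated fraction instead.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax.numpy as jnp

x = jnp.ones((2, 2))  # arrays are now allocated on demand on GPU backends
```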

@ricardoV94
Member

ricardoV94 commented Jan 25, 2024

(2) showed a smaller footprint (1.8GiB), but still not as small as what's in the PR (scan for chains and vmap for draws -- which used 1.6GiB)

I'm surprised by this, but I guess JAX's naively traced transpose+scan just sucks memory-wise (the JAX docs say swapaxes may return copies, which I guess is what it's doing in your case).
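
A rough sketch of the view-versus-copy distinction being referred to: NumPy's swapaxes returns a strided view of the same memory, whereas the sketch assumes an eager jnp.swapaxes materializes a separate array, so both would be alive at once. The array shapes are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

np_x = np.zeros((4, 1000))
np_t = np.swapaxes(np_x, 0, 1)
# NumPy hands back a view, so no extra data buffer is allocated.
numpy_shares = np.shares_memory(np_x, np_t)

jax_x = jnp.zeros((4, 1000))
jax_t = jnp.swapaxes(jax_x, 0, 1)
# If the transpose is materialized, both buffers count toward the footprint
# for as long as both arrays are referenced.
live_bytes = jax_x.nbytes + jax_t.nbytes
```

This would be consistent with the listcomp entry of roughly one output-sized allocation in the test_scan_vmap report, though that link is an inference rather than something verified here.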

The problem is that I don't know if this is general. @fonnesbeck could you test if this also fixes the memory problems you were seeing in your model?

@ricardoV94
Member

ricardoV94 commented Jan 25, 2024

@JasonTam do you want to try nested scan as well? Since you already did so many permutations :b

That should be the most extreme at the opposite side of just vmap

@JasonTam
Contributor Author

@JasonTam do you want to try nested scan as well? Since you already did so many permutations :b

That should be the most extreme at the opposite side of just vmap

Nested scan is the implementation (2) in

I was curious to see:

  1. What if we switched what scan / vmap were being used for, i.e., scan over draws and vmap over chains
  2. What if we used scan for both chains and draws

(2) showed a smaller footprint (1.8GiB), but still not as small as what's in the PR (scan for chains and vmap for draws -- which used 1.6GiB)

where scan_over_chains calls scan_over_draws

@ricardoV94
Member

ricardoV94 commented Jan 25, 2024

Thanks @JasonTam I missed it.

Due to the fear of overfitting to one example I would perhaps go for nested scan? WDYT?

@JasonTam
Contributor Author

JasonTam commented Jan 25, 2024

Due to the fear of over fitting to one example I would perhaps go for nested Scan? WDYT?

I too am afraid of my test set-up not generalizing well. I'm going to try to run some more tests. But also, since this is probably a widely used function for most users, I'd consider putting it under another option in postprocessing_vectorize: Literal["vmap", "scan", "nested_scan"], and potentially keeping the default as "scan".

I also think the existing options "vmap" and "scan" are a little misleading, since both use vmap over the chain dimension. They feel more like "nested_vmap", "vmap_scan", etc.

I would definitely also appreciate feedback from @ferrine, as the previous author of these bits.
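
The three-way option floated above could look roughly like the sketch below. The function name `postprocess`, its signature, and `toy_fn` are hypothetical and do not reflect the actual pymc API; the "scan" branch mirrors the transpose-based path discussed in this PR.

```python
from typing import Literal

import jax
import jax.numpy as jnp
from jax.lax import scan


def postprocess(fn, samples, method: Literal["vmap", "scan", "nested_scan"] = "scan"):
    """Hypothetical dispatcher; `samples` is a list of (chains, draws, ...) arrays."""
    if method == "vmap":
        # vmap over chains and over draws
        return jax.vmap(jax.vmap(fn))(*samples)
    if method == "scan":
        # scan over draws, vmap over chains (needs the two transposes)
        t_samples = [jnp.swapaxes(t, 0, 1) for t in samples]
        _, outs = scan(lambda _, x: ((), jax.vmap(fn)(*x)), (), t_samples)
        return [jnp.swapaxes(t, 0, 1) for t in outs]
    if method == "nested_scan":
        # scan over chains, then scan over draws within each chain
        def over_draws(x):
            _, outs = scan(lambda _, xx: ((), fn(*xx)), (), x)
            return outs

        _, outs = scan(lambda _, x: ((), over_draws(x)), (), samples)
        return outs
    raise ValueError(f"Unknown method: {method!r}")


def toy_fn(a, b):
    # Hypothetical stand-in for the jaxified graph.
    return [a * 2.0, a + b.sum()]


samples = [jnp.arange(8.0).reshape(2, 4), jnp.ones((2, 4, 3))]
results = {m: postprocess(toy_fn, samples, m) for m in ("vmap", "scan", "nested_scan")}
```

All three branches return arrays shaped (chains, draws, ...), so the choice would only affect how the work is scheduled, not the output layout.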

@ricardoV94
Member

@JasonTam any news? Would be great to patch this one up :)

@JasonTam
Contributor Author

JasonTam commented Feb 6, 2024

@ricardoV94 I haven't had time to play with this unfortunately.
The only news I have is that I was not able to reproduce this earlier claim of mine:

But FWIW, this change does allow my large model to finish sampling now with a 30 GB memory limit when it was previously OOM under a 48 GB limit.

When testing this larger model, some of these methods were failing. But I need to wait my turn on a cluster to make sure there's no interference, so testing these methods has been slow.

@JasonTam
Contributor Author

@ricardoV94 here are some results from a larger test:
raw_mcmc_samples looks like:

 [
(4, 1000)
(4, 1000)
(4, 1000, 468)
(4, 1000, 7, 468)
(4, 1000, 2, 9, 468)
]

5 variables of (4 chains, 1000 draws, ...)

Tested on an Azure k8s cluster with Epdsv5-series VMs (3.0 GHz CPU), where each job has plenty of CPU and memory to spare.

| Method name | Method for chains dim | Method for draws dim | Peak memory [GiB] | Total allocations | Call duration [s] |
|---|---|---|---|---|---|
| scan_vmap via listcomp transpose (control) | vmap | scan | 28.765 | 1451658 | 8.219 |
| scan_vmap_wo_transpose | scan | vmap | 15.938 | 828079 | 6.579 |
| vmap_scan | vmap | scan | 28.532 | 1452441 | 8.318 |
| nested_scan | scan | scan | 17.903 | 1277702 | 8.293 |
| nested_vmap | vmap | vmap | 30.123 | 1827578 | 9.175 |
| vmap_map | vmap | map | 28.532 | 1451486 | 8.280 |
| nested_xmap | xmap | xmap | 20.703 | 623790 | 5.592 |
| looped_vmap | python loop | vmap | 34.933 | 6474397 | 20.379 |

I hope I'm understanding correctly which method goes with which dimension. From these results, it does seem like scan_vmap_wo_transpose has the lowest peak memory. (That variant simply removes the transposes done via list comprehension, as seen in this PR.)
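
The vmap_map variant in the results is not shown in the thread. One guess at what it might look like, assuming vmap over chains and a sequential jax.lax.map over draws, is sketched here; `toy_fn`, `vmap_map`, and `samples` are illustrative names only.

```python
import jax
import jax.numpy as jnp


def toy_fn(a, b):
    # Hypothetical stand-in for the jaxified graph.
    return [a * 2.0, a + b.sum()]


def vmap_map(fn, samples):
    # Guess at "vmap_map": vmap over chains, sequential lax.map over draws.
    def over_draws(*x):
        # x holds one chain's arrays, each shaped (draws, ...)
        return jax.lax.map(lambda xx: fn(*xx), x)

    return jax.vmap(over_draws)(*samples)


samples = [jnp.arange(8.0).reshape(2, 4), jnp.ones((2, 4, 3))]  # (chains, draws, ...)
outs = vmap_map(toy_fn, samples)
reference = jax.vmap(jax.vmap(toy_fn))(*samples)
```

Since jax.lax.map is built on scan, it would not be surprising for such a variant to profile almost identically to vmap_scan, as the matching peak-memory figures suggest.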

@ricardoV94
Member

I still like the nested scan better, because we know that vmap was the source of the problem in the case that first motivated these changes. Unfortunately, we don't have a way to retrieve that example, but I suspect the current solution in this PR would be a regression there.


Successfully merging this pull request may close these issues.

BUG: Jax-based samplers crash at transformation stage
2 participants