diff --git a/.github/workflows/travis.yml b/.github/workflows/travis.yml index e3b9d002c6..e3f431957d 100644 --- a/.github/workflows/travis.yml +++ b/.github/workflows/travis.yml @@ -50,7 +50,6 @@ jobs: - name: Install minimal dependencies run: | $CONDA/bin/pip install -r min-requirements.txt - $CONDA/bin/pip install . $CONDA/bin/pip install .[tests] - uses: actions/cache@v2 diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 483e8d301f..2ccc15ae61 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -227,6 +227,8 @@ class Coefficients2Warp(SimpleInterface): output_spec = _Coefficients2WarpOutputSpec def _run_interface(self, runtime): + from ..utils.misc import get_free_mem + # Calculate the physical coordinates of target grid targetnii = nb.load(self.inputs.in_target) targetaff = targetnii.affine @@ -237,11 +239,21 @@ def _run_interface(self, runtime): weights = [] coeffs = [] blocksize = LOW_MEM_BLOCK_SIZE if self.inputs.low_mem else len(points) + for cname in self.inputs.in_coeff: cnii = nb.load(cname) cdata = cnii.get_fdata(dtype="float32") coeffs.append(cdata.reshape(-1)) + # Try to probe the free memory + _free_mem = get_free_mem() + suggested_blocksize = ( + int(np.round((_free_mem * 0.80) / (3 * 32 * cdata.size))) + if _free_mem + else blocksize + ) + blocksize = min(blocksize, suggested_blocksize) + idx = 0 block_w = [] while True: diff --git a/sdcflows/utils/misc.py b/sdcflows/utils/misc.py index 0fc9a675da..3d5d390c2f 100644 --- a/sdcflows/utils/misc.py +++ b/sdcflows/utils/misc.py @@ -35,3 +35,13 @@ def last(inlist): if isinstance(inlist, (list, tuple)): return inlist[-1] return inlist + + +def get_free_mem(): + """Probe the free memory right now.""" + try: + from psutil import virtual_memory + + return round(virtual_memory().free, 1) + except Exception: + return None diff --git a/sdcflows/utils/tests/test_misc.py b/sdcflows/utils/tests/test_misc.py new file mode 100644 index 0000000000..da7a3d9f09 --- /dev/null +++ b/sdcflows/utils/tests/test_misc.py @@ -0,0 +1,21 @@ +"""Test miscellaneous utilities.""" +import sys +from collections import namedtuple +import types +import pytest +from ..misc import get_free_mem + + +@pytest.mark.parametrize("retval", [None, 10]) +def test_get_free_mem(monkeypatch, retval): + """Test the get_free_mem utility.""" + + def mock_func(): + if retval is None: + raise ImportError + return namedtuple("Mem", ("free",))(free=retval) + + psutil = types.ModuleType("psutil") + psutil.virtual_memory = mock_func + monkeypatch.setitem(sys.modules, "psutil", psutil) + assert get_free_mem() == retval diff --git a/setup.cfg b/setup.cfg index 446767a327..033a262473 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,6 +56,8 @@ doc = sphinxcontrib-versioning docs = %(doc)s +mem = + psutil tests = pytest pytest-xdist >= 2.0 @@ -64,6 +66,7 @@ tests = coverage all = %(doc)s + %(mem)s %(tests)s [options.package_data]