Skip to content

Commit

Permalink
Merge pull request #156 from oesteban/fix/auto-lowmem-mode
Browse files Browse the repository at this point in the history
ENH: Add a memory check to dynamically limit interpolation blocksize
  • Loading branch information
oesteban authored Dec 14, 2020
2 parents 92350a2 + c13fde3 commit 5735be8
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 1 deletion.
1 change: 0 additions & 1 deletion .github/workflows/travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ jobs:
- name: Install minimal dependencies
run: |
$CONDA/bin/pip install -r min-requirements.txt
$CONDA/bin/pip install .
$CONDA/bin/pip install .[tests]
- uses: actions/cache@v2
Expand Down
12 changes: 12 additions & 0 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ class Coefficients2Warp(SimpleInterface):
output_spec = _Coefficients2WarpOutputSpec

def _run_interface(self, runtime):
from ..utils.misc import get_free_mem

# Calculate the physical coordinates of target grid
targetnii = nb.load(self.inputs.in_target)
targetaff = targetnii.affine
Expand All @@ -237,11 +239,21 @@ def _run_interface(self, runtime):
weights = []
coeffs = []
blocksize = LOW_MEM_BLOCK_SIZE if self.inputs.low_mem else len(points)

for cname in self.inputs.in_coeff:
cnii = nb.load(cname)
cdata = cnii.get_fdata(dtype="float32")
coeffs.append(cdata.reshape(-1))

# Try to probe the free memory
_free_mem = get_free_mem()
suggested_blocksize = (
int(np.round((_free_mem * 0.80) / (3 * 32 * cdata.size)))
if _free_mem
else blocksize
)
blocksize = min(blocksize, suggested_blocksize)

idx = 0
block_w = []
while True:
Expand Down
10 changes: 10 additions & 0 deletions sdcflows/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,13 @@ def last(inlist):
if isinstance(inlist, (list, tuple)):
return inlist[-1]
return inlist


def get_free_mem():
"""Probe the free memory right now."""
try:
from psutil import virtual_memory

return round(virtual_memory().free, 1)
except Exception:
return None
21 changes: 21 additions & 0 deletions sdcflows/utils/tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Test miscellaneous utilities."""
import sys
from collections import namedtuple
import types
import pytest
from ..misc import get_free_mem


@pytest.mark.parametrize("retval", [None, 10])
def test_get_free_mem(monkeypatch, retval):
"""Test the get_free_mem utility."""

def mock_func():
if retval is None:
raise ImportError
return namedtuple("Mem", ("free",))(free=retval)

psutil = types.ModuleType("psutil")
psutil.virtual_memory = mock_func
monkeypatch.setitem(sys.modules, "psutil", psutil)
assert get_free_mem() == retval
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ doc =
sphinxcontrib-versioning
docs =
%(doc)s
mem =
psutil
tests =
pytest
pytest-xdist >= 2.0
Expand All @@ -64,6 +66,7 @@ tests =
coverage
all =
%(doc)s
%(mem)s
%(tests)s

[options.package_data]
Expand Down

0 comments on commit 5735be8

Please sign in to comment.