Merge pull request #77 from constantinpape/threadctrl
Use threadctl to set numpy thread limits in scope
constantinpape authored Jul 22, 2023
2 parents d0718bb + 090da21 commit d7b8f6c
Showing 9 changed files with 166 additions and 70 deletions.
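The change replaces the module-level `set_numpy_threads(1)` calls, which only took effect if they ran before numpy's first import, with `threadpoolctl`, which limits the BLAS/OpenMP thread pools of already-loaded libraries in scope. Below is a minimal sketch of the two forms used throughout the diff, assuming threadpoolctl >= 3.1 (which provides `threadpool_limits.wrap`); it is illustrative, not the repository's code.

# Minimal sketch of the pattern adopted in this commit (not the repository's code).
# IMPORTANT: import threadpoolctl before numpy, as the changed files do.
from threadpoolctl import threadpool_limits

import numpy as np


@threadpool_limits.wrap(limits=1)  # decorator form: the limit applies for each call
def _process_block(block):
    # numpy/BLAS calls inside this function will not spawn extra threads
    return block.sum()


def heavy_step(x, n_jobs=4):
    # context-manager form: the limit applies only inside the with-block
    with threadpool_limits(limits=n_jobs):
        return x @ x.T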
8 changes: 6 additions & 2 deletions elf/parallel/label.py
@@ -1,3 +1,6 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing
# would be nice to use dask, so that we can also run this on the cluster
from concurrent import futures
@@ -10,8 +13,6 @@
import nifty.ufd as nufd
from .common import get_blocking

from elf.util import set_numpy_threads
set_numpy_threads(1)
import numpy as np


@@ -20,6 +21,7 @@ def cc_blocks(data, out, mask, blocking, with_background,
n_blocks = blocking.numberOfBlocks

# compute the connected component for one block
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _cc_block(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
@@ -65,6 +67,7 @@ def merge_blocks(data, out, mask, offsets,
n_blocks = blocking.numberOfBlocks
ndim = out.ndim

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _merge_block_faces(block_id):
block = blocking.getBlock(block_id)
offset_block = offsets[block_id]
@@ -165,6 +168,7 @@ def write_mapping(out, mask, offsets, mapping,
n_blocks = blocking.numberOfBlocks

# compute the connected component for one block
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _write_block(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
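Each per-block worker above is decorated with `threadpool_limits.wrap(limits=1)` because the blocks are already processed in parallel by a `concurrent.futures` pool; letting numpy spawn its own BLAS threads on top of that would oversubscribe the CPU. A small sketch of the combination, with hypothetical names:

# Sketch only: n_threads workers x 1 BLAS thread each -> no oversubscription.
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from threadpoolctl import threadpool_limits
import numpy as np


@threadpool_limits.wrap(limits=1)  # one BLAS thread per worker
def _block_max(data, bb):
    # bb is a tuple of slices selecting one block
    return np.asarray(data[bb]).max()


def blockwise_max(data, blocks, n_threads=8):
    with ThreadPoolExecutor(n_threads) as pool:
        return max(pool.map(partial(_block_max, data), blocks))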
26 changes: 13 additions & 13 deletions elf/parallel/operations.py
@@ -1,3 +1,6 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing
# would be nice to use dask for all of this instead of concurrent.futures
# so that this could be used on a cluster as well
@@ -7,8 +10,6 @@
from tqdm import tqdm

from .common import get_blocking
from ..util import set_numpy_threads
set_numpy_threads(1)
import numpy as np


@@ -59,14 +60,15 @@ def isin(x, y, out=None,
blocking = get_blocking(x, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _isin(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))

# check if we have a mask and if we do if we
# have pixels in the mask
if mask is not None:
m = mask[bb].astype('bool')
m = mask[bb].astype("bool")
if m.sum() == 0:
return None

@@ -142,14 +144,15 @@ def apply_operation(x, y, operation, out=None,
blocking = get_blocking(x, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _apply_scalar(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))

# check if we have a mask and if we do if we
# have pixels in the mask
if mask is not None:
m = mask[bb].astype('bool')
m = mask[bb].astype("bool")
if m.sum() == 0:
return None

@@ -173,7 +176,7 @@ def _apply_array(block_id):
# check if we have a mask and if we do if we
# have pixels in the mask
if mask is not None:
m = mask[bb].astype('bool')
m = mask[bb].astype("bool")
if m.sum() == 0:
return None

@@ -238,14 +241,15 @@ def apply_operation_single(x, operation, axis=None, out=None,
blocking = get_blocking(out, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _apply(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))

# check if we have a mask and if we do if we
# have pixels in the mask
if mask is not None:
m = mask[bb].astype('bool')
m = mask[bb].astype("bool")
if m.sum() == 0:
return None

@@ -302,17 +306,13 @@ def op(x, y, out=None, block_shape=None, n_threads=None,


# autogenerate parallel implementation for common numpy operations
_op_names = ['add', 'subtract', 'multiply', 'divide',
'greater', 'greater_equal', 'less', 'less_equal',
'minimum', 'maximum']
_op_names = ["add", "subtract", "multiply", "divide",
"greater", "greater_equal", "less", "less_equal",
"minimum", "maximum"]


for op_name in _op_names:
_generate_operation(op_name)

del _generate_operation
del _op_names


# TODO autogenerate parallel implementation for common single operand numpy operations
# _op_nams = ['mean', 'max', 'min', 'std']
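The loop above creates one module-level wrapper per listed numpy operation via `_generate_operation`, whose body is not part of this diff; the commit also drops the trailing `del _generate_operation` / `del _op_names` cleanup. A plausible sketch of such autogeneration, with the blockwise dispatch itself elided, is:

# Plausible sketch only; the actual _generate_operation is not shown in this diff.
import numpy as np


def _generate_operation(op_name):
    np_op = getattr(np, op_name)

    def op(x, y, out=None, block_shape=None, n_threads=None, mask=None):
        # a full implementation would apply np_op blockwise in a thread pool;
        # here we simply forward to numpy to illustrate the generated signature
        return np_op(x, y, out=out)

    op.__name__ = op_name
    globals()[op_name] = op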
8 changes: 6 additions & 2 deletions elf/parallel/relabel.py
@@ -1,13 +1,16 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing

# would be nice to use dask, so that we can also run this on the cluster
from concurrent import futures
from tqdm import tqdm
import nifty.tools as nt

from .unique import unique
from .common import get_blocking
from ..util import set_numpy_threads
set_numpy_threads(1)

import numpy as np


@@ -51,6 +54,7 @@ def relabel_consecutive(data, start_label=0, keep_zeros=True, out=None,
raise ValueError("Expect data and out of same shape, got %s and %s" % (str(data.shape),
str(out.shape)))

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _relabel(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
8 changes: 6 additions & 2 deletions elf/parallel/size_filter.py
@@ -1,3 +1,6 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing
# would be nice to use dask, so that we can also run this on the cluster
from concurrent import futures
@@ -7,8 +10,7 @@

from .common import get_blocking
from .unique import unique
from ..util import set_numpy_threads
set_numpy_threads(1)

import numpy as np


@@ -85,6 +87,7 @@ def size_filter(data, out, min_size=None, max_size=None,
if 0 in mapping:
assert mapping[0] == 0

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _relabel(seg, block_mask):
if block_mask is None or block_mask.sum() == block_mask.size:
ids_in_block = np.unique(seg)
@@ -101,6 +104,7 @@ def _relabel(seg, block_mask):
else:
_relabel = None

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def filter_function(block_seg, block_mask):
bg_mask = np.isin(block_seg, filter_ids)
if block_mask is not None:
9 changes: 7 additions & 2 deletions elf/parallel/stats.py
@@ -1,12 +1,14 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing
# would be nice to use dask for all of this instead of concurrent.futures
# so that this could be used on a cluster as well
from concurrent import futures
from tqdm import tqdm

from .common import get_blocking
from ..util import set_numpy_threads
set_numpy_threads(1)

import numpy as np


@@ -29,6 +31,7 @@ def mean(data, block_shape=None, n_threads=None, mask=None, verbose=False, roi=N
blocking = get_blocking(data, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _mean(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
@@ -77,6 +80,7 @@ def mean_and_std(data, block_shape=None, n_threads=None, mask=None, verbose=Fals
blocking = get_blocking(data, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _mean_and_std(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
@@ -150,6 +154,7 @@ def min_and_max(data, block_shape=None, n_threads=None, mask=None, verbose=False
blocking = get_blocking(data, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _min_and_max(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
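The statistics in stats.py are computed per block and then merged. One standard way to merge per-block partial sums into a global mean and standard deviation, shown as a sketch rather than the exact reduction used in this file:

# Sketch of a standard per-block reduction (not necessarily elf's exact code).
import numpy as np


def combine_mean_and_std(block_results):
    # block_results: iterable of (sum, sum_of_squares, count) per block
    total, total_sq, n = map(sum, zip(*block_results))
    mean = total / n
    var = total_sq / n - mean ** 2
    return mean, np.sqrt(max(var, 0.0))  # clamp tiny negative rounding errors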
7 changes: 5 additions & 2 deletions elf/parallel/unique.py
@@ -1,11 +1,13 @@
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import multiprocessing
# would be nice to use dask, so that we can also run this on the cluster
from concurrent import futures
from tqdm import tqdm

from .common import get_blocking
from ..util import set_numpy_threads
set_numpy_threads(1)

import numpy as np


@@ -31,6 +33,7 @@ def unique(data, return_counts=False, block_shape=None, n_threads=None,
blocking = get_blocking(data, block_shape, roi)
n_blocks = blocking.numberOfBlocks

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def _unique(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
26 changes: 17 additions & 9 deletions elf/segmentation/embeddings.py
@@ -1,5 +1,6 @@
from ..util import set_numpy_threads
set_numpy_threads(1)
# IMPORTANT do threadctl import first (before numpy imports)
from threadpoolctl import threadpool_limits

import numpy as np
import vigra
try:
@@ -147,16 +148,22 @@ def _cluster(embeddings, clustering_alg, semantic_mask=None, remove_largest=Fals

def segment_hdbscan(embeddings, min_size, eps, remove_largest, n_jobs=1):
assert hdbscan is not None, "Needs hdbscan library"
clustering = hdbscan.HDBSCAN(min_cluster_size=min_size, cluster_selection_epsilon=eps, core_dist_n_jobs=n_jobs)
return _cluster(embeddings, clustering, remove_largest=remove_largest).astype("uint64")
with threadpool_limits(limits=n_jobs):
clustering = hdbscan.HDBSCAN(
min_cluster_size=min_size, cluster_selection_epsilon=eps, core_dist_n_jobs=n_jobs
)
result = _cluster(embeddings, clustering, remove_largest=remove_largest).astype("uint64")
return result


def segment_mean_shift(embeddings, bandwidth, n_jobs=1):
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
return _cluster(embeddings, clustering).astype("uint64")
with threadpool_limits(limits=n_jobs):
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
result = _cluster(embeddings, clustering).astype("uint64")
return result


def segment_consistency(embeddings1, embeddings2, bandwidth, iou_threshold, num_anchors, skip_zero=True):
def segment_consistency(embeddings1, embeddings2, bandwidth, iou_threshold, num_anchors, skip_zero=True, n_jobs=1):
def _iou(gt, seg):
epsilon = 1e-5
inter = (gt & seg).sum()
@@ -165,8 +172,9 @@ def _iou(gt, seg):
iou = (inter + epsilon) / (union + epsilon)
return iou

clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True)
clusters = _cluster(embeddings1, clustering)
with threadpool_limits(limits=n_jobs):
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
clusters = _cluster(embeddings1, clustering)

for label_id in np.unique(clusters):
if label_id == 0 and skip_zero:
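In embeddings.py the limit is applied with the context-manager form around BLAS-heavy scikit-learn clustering, so the estimator's own `n_jobs` parallelism is not multiplied by extra BLAS threads. An illustrative sketch of that pattern; the flattening of the embedding volume is an assumption, not code from this file:

# Sketch: scope thread limits around a BLAS-heavy sklearn call.
import numpy as np
from sklearn.cluster import MeanShift
from threadpoolctl import threadpool_limits


def cluster_embeddings(embeddings, bandwidth=0.5, n_jobs=4):
    # assumed layout (C, Z, Y, X) -> (n_pixels, C); adapt to the real data
    flat = embeddings.reshape(embeddings.shape[0], -1).T
    with threadpool_limits(limits=n_jobs):
        clustering = MeanShift(bandwidth=bandwidth, bin_seeding=True, n_jobs=n_jobs)
        return clustering.fit_predict(flat)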
43 changes: 5 additions & 38 deletions elf/util.py
@@ -1,6 +1,4 @@
import ctypes
import numbers
import os
from math import ceil
from itertools import product

@@ -9,7 +7,7 @@ def slice_to_start_stop(s, size):
"""For a single dimension with a given size, normalize slice to size.
Returns slice(None, 0) if slice is invalid."""
if s.step not in (None, 1):
raise ValueError('Nontrivial steps are not supported')
raise ValueError("Nontrivial steps are not supported")

if s.start is None:
start = 0
@@ -39,15 +37,15 @@ def int_to_start_stop(i, size):
if -size < i < 0:
start = i + size
elif i >= size or i < -size:
raise ValueError('Index ({}) out of range (0-{})'.format(i, size - 1))
raise ValueError("Index ({}) out of range (0-{})".format(i, size - 1))
else:
start = i
return slice(start, start + 1)


# For now, I have copied the z5 implementation:
# https://github.com/constantinpape/z5/blob/master/src/python/module/z5py/shape_utils.py#L126
# But it's worth taking a look at @clbarnes more general implementation too
# But it"s worth taking a look at @clbarnes more general implementation too
# https://github.com/clbarnes/h5py_like
def normalize_index(index, shape):
""" Normalize index to shape.
@@ -62,8 +60,8 @@
tuple[slice]: normalized slices (start and stop are both non-None)
tuple[int]: which singleton dimensions should be squeezed out
"""
type_msg = 'Advanced selection inappropriate. ' \
'Only numbers, slices (`:`), and ellipsis (`...`) are valid indices (or tuples thereof)'
type_msg = "Advanced selection inappropriate. " \
"Only numbers, slices (`:`), and ellipsis (`...`) are valid indices (or tuples thereof)"

if isinstance(index, tuple):
slices_lst = list(index)
@@ -203,37 +201,6 @@ def downscale_shape(shape, scale_factor, ceil_mode=True):
return tuple(sh // sf for sh, sf in zip(shape, scale_))


def set_numpy_threads(n_threads):
""" Set the number of threads numpy exposes to its
underlying linalg library.
This needs to be called BEFORE the numpy import and sets the number
of threads statically.
Based on answers in https://github.com/numpy/numpy/issues/11826.
"""

# set number of threads for mkl if it is used
try:
import mkl
mkl.set_num_threaads(n_threads)
except Exception:
pass

for name in ['libmkl_rt.so', 'libmkl_rt.dylib', 'mkl_Rt.dll']:
try:
mkl_rt = ctypes.CDLL(name)
mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(n_threads)))
except Exception:
pass

# set number of threads in all possibly relevant environment variables
os.environ['OMP_NUM_THREADS'] = str(n_threads)
os.environ['OPENBLAS_NUM_THREADS'] = str(n_threads)
os.environ['MKL_NUM_THREADS'] = str(n_threads)
os.environ['VECLIB_NUM_THREADS'] = str(n_threads)
os.environ['NUMEXPR_NUM_THREADS'] = str(n_threads)


def sigma_to_halo(sigma, order):
""" Compute the halo value to apply filter in parallel.
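The removed `set_numpy_threads` helper set environment variables and poked MKL through ctypes, which only takes effect before numpy is first imported. `threadpoolctl` instead adjusts the thread pools of libraries that are already loaded and can report what it found; a small sketch for verifying that a limit is active:

# Sketch: inspect the active threadpools to confirm the limit is in effect.
from threadpoolctl import threadpool_info, threadpool_limits


with threadpool_limits(limits=1, user_api="blas"):
    for pool in threadpool_info():
        print(pool["user_api"], pool["internal_api"], pool["num_threads"])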