Skip to content

Commit

Permalink
add isplit algorithm
Browse files Browse the repository at this point in the history
    Get indices to split a sequence into a number of chunks.  This
    algorithm produces nearly equal chunks when the data cannot be
    equally split.

    Based on the algorithm for splitting arrays in numpy.array_split
  • Loading branch information
esheldon committed Oct 28, 2024
1 parent ba4ee98 commit 58f6f64
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
49 changes: 49 additions & 0 deletions esutil/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,55 @@
"""


def isplit(num, nchunks):
"""
Get indices to split a sequence into a number of chunks. This algorithm
produces nearly equal chunks when the data cannot be equally split.
Based on the algorithm for splitting arrays in numpy.array_split
Parameters
----------
num: int
Number of elements to be split into chunks
nchunks: int
Number of chunks
Returns
-------
subs: array
Array with fields 'start' and 'end for each chunks, so a chunk can
be gotten with
subs = isplit(arr.size, nchunks)
for i in range(nchunks):
achunk = array[subs['start'][i]:subs['end'][i]]
"""
import numpy as np

nchunks = int(nchunks)

if nchunks <= 0:
raise ValueError(f'got nchunks={nchunks} < 0')

neach_section, extras = divmod(num, nchunks)

section_sizes = (
[0] + extras * [neach_section+1]
+ (nchunks-extras) * [neach_section]
)
div_points = np.array(section_sizes, dtype=np.intp).cumsum()

subs = np.zeros(nchunks, dtype=[('start', 'i8'), ('end', 'i8')])

for i in range(nchunks):
subs['start'][i] = div_points[i]
subs['end'][i] = div_points[i + 1]

return subs


def quicksort(data):
"""
Name:
Expand Down
16 changes: 16 additions & 0 deletions esutil/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,19 @@ def test_quicksort():
s = data_orig.argsort()

assert np.all(data == data_orig[s])


def test_isplit():
num = 135
nchunks = 11

subs = eu.algorithm.isplit(num=num, nchunks=nchunks)
assert subs.size == 11

assert np.all(
subs['start'] == [0, 13, 26, 39, 51, 63, 75, 87, 99, 111, 123]
)

assert np.all(
subs['end'] == [13, 26, 39, 51, 63, 75, 87, 99, 111, 123, 135]
)

0 comments on commit 58f6f64

Please sign in to comment.