Skip to content

Commit

Permalink
Merge pull request #379 from QuantEcon/k_array
Browse files Browse the repository at this point in the history
Re-implement `next_k_array`; add `k_array_rank`
  • Loading branch information
mmcky authored Jan 5, 2018
2 parents 475650b + a215f1f commit e61702f
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 83 deletions.
86 changes: 3 additions & 83 deletions quantecon/game_theory/support_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from numba import jit
from ..util.numba import _numba_linalg_solve
from ..util.combinatorics import next_k_array


def support_enumeration(g):
Expand Down Expand Up @@ -106,8 +107,8 @@ def _support_enumeration_gen(payoff_matrix0, payoff_matrix1):
actions)):
out[p][supp] = action[:-1]
yield out
_next_k_array(supps[1])
_next_k_array(supps[0])
next_k_array(supps[1])
next_k_array(supps[0])


@jit(nopython=True, cache=True)
Expand Down Expand Up @@ -180,84 +181,3 @@ def _indiff_mixed_action(payoff_matrix, own_supp, opp_supp, A, out):
if payoff > val:
return False
return True


@jit(nopython=True, cache=True)
def _next_k_combination(x):
"""
Find the next k-combination, as described by an integer in binary
representation with the k set bits, by "Gosper's hack".
Copy-paste from en.wikipedia.org/wiki/Combinatorial_number_system
Parameters
----------
x : int
Integer with k set bits.
Returns
-------
int
Smallest integer > x with k set bits.
"""
u = x & -x
v = u + x
return v + (((v ^ x) // u) >> 2)


@jit(nopython=True, cache=True)
def _next_k_array(a):
"""
Given an array `a` of k distinct nonnegative integers, return the
next k-array in lexicographic ordering of the descending sequences
of the elements. `a` is modified in place.
Parameters
----------
a : ndarray(int, ndim=1)
Array of length k.
Returns
-------
a : ndarray(int, ndim=1)
View of `a`.
Examples
--------
Enumerate all the subsets with k elements of the set {0, ..., n-1}.
>>> n, k = 4, 2
>>> a = np.arange(k)
>>> while a[-1] < n:
... print(a)
... a = _next_k_array(a)
...
[0 1]
[0 2]
[1 2]
[0 3]
[1 3]
[2 3]
"""
k = len(a)
if k == 0:
return a

x = 0
for i in range(k):
x += (1 << a[i])

x = _next_k_combination(x)

pos = 0
for i in range(k):
while x & 1 == 0:
x = x >> 1
pos += 1
a[i] = pos
x = x >> 1
pos += 1

return a
122 changes: 122 additions & 0 deletions quantecon/util/combinatorics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Useful routines for combinatorics
"""
from scipy.special import comb
from numba import jit

from .numba import comb_jit


@jit(nopython=True, cache=True)
def next_k_array(a):
"""
Given an array `a` of k distinct nonnegative integers, sorted in
ascending order, return the next k-array in the lexicographic
ordering of the descending sequences of the elements [1]_. `a` is
modified in place.
Parameters
----------
a : ndarray(int, ndim=1)
Array of length k.
Returns
-------
a : ndarray(int, ndim=1)
View of `a`.
Examples
--------
Enumerate all the subsets with k elements of the set {0, ..., n-1}.
>>> n, k = 4, 2
>>> a = np.arange(k)
>>> while a[-1] < n:
... print(a)
... a = next_k_array(a)
...
[0 1]
[0 2]
[1 2]
[0 3]
[1 3]
[2 3]
References
----------
.. [1] `Combinatorial number system
<https://en.wikipedia.org/wiki/Combinatorial_number_system>`_,
Wikipedia.
"""
# Logic taken from Algotirhm T in D. Knuth, The Art of Computer
# Programming, Section 7.2.1.3 "Generating All Combinations".
k = len(a)
if k == 1 or a[0] + 1 < a[1]:
a[0] += 1
return a

a[0] = 0
i = 1
x = a[i] + 1

while i < k-1 and x == a[i+1]:
i += 1
a[i-1] = i - 1
x = a[i] + 1
a[i] = x

return a


def k_array_rank(a):
"""
Given an array `a` of k distinct nonnegative integers, sorted in
ascending order, return its ranking in the lexicographic ordering of
the descending sequences of the elements [1]_.
Parameters
----------
a : ndarray(int, ndim=1)
Array of length k.
Returns
-------
idx : scalar(int)
Ranking of `a`.
References
----------
.. [1] `Combinatorial number system
<https://en.wikipedia.org/wiki/Combinatorial_number_system>`_,
Wikipedia.
"""
k = len(a)
idx = int(a[0]) # Convert to Python int
for i in range(1, k):
idx += comb(a[i], i+1, exact=True)
return idx


@jit(nopython=True, cache=True)
def k_array_rank_jit(a):
"""
Numba jit version of `k_array_rank`.
Notes
-----
An incorrect value will be returned without warning or error if
overflow occurs during the computation. It is the user's
responsibility to ensure that the rank of the input array fits
within the range of possible values of `np.intp`; a sufficient
condition for it is `scipy.special.comb(a[-1]+1, len(a), exact=True)
<= np.iinfo(np.intp).max`.
"""
k = len(a)
idx = a[0]
for i in range(1, k):
idx += comb_jit(a[i], i+1)
return idx
70 changes: 70 additions & 0 deletions quantecon/util/tests/test_combinatorics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
Tests for util/combinatorics.py
"""
import numpy as np
from numpy.testing import assert_array_equal
from nose.tools import eq_
import scipy.special
from quantecon.util.combinatorics import (
next_k_array, k_array_rank, k_array_rank_jit
)


class TestKArray:
def setUp(self):
self.k_arrays = np.array(
[[0, 1, 2],
[0, 1, 3],
[0, 2, 3],
[1, 2, 3],
[0, 1, 4],
[0, 2, 4],
[1, 2, 4],
[0, 3, 4],
[1, 3, 4],
[2, 3, 4],
[0, 1, 5],
[0, 2, 5],
[1, 2, 5],
[0, 3, 5],
[1, 3, 5],
[2, 3, 5],
[0, 4, 5],
[1, 4, 5],
[2, 4, 5],
[3, 4, 5]]
)
self.L, self.k = self.k_arrays.shape

def test_next_k_array(self):
k_arrays_computed = np.empty((self.L, self.k), dtype=int)
k_arrays_computed[0] = np.arange(self.k)
for i in range(1, self.L):
k_arrays_computed[i] = k_arrays_computed[i-1]
next_k_array(k_arrays_computed[i])
assert_array_equal(k_arrays_computed, self.k_arrays)

def test_k_array_rank(self):
for i in range(self.L):
eq_(k_array_rank(self.k_arrays[i]), i)

def test_k_array_rank_jit(self):
for i in range(self.L):
eq_(k_array_rank_jit(self.k_arrays[i]), i)


def test_k_array_rank_arbitrary_precision():
n, k = 100, 50
a = np.arange(n-k, n)
eq_(k_array_rank(a), scipy.special.comb(n, k, exact=True)-1)


if __name__ == '__main__':
import sys
import nose

argv = sys.argv[:]
argv.append('--verbose')
argv.append('--nocapture')
nose.main(argv=argv, defaultTest=__file__)

0 comments on commit e61702f

Please sign in to comment.