Skip to content

Commit

Permalink
Add population_count to lax_reference
Browse files Browse the repository at this point in the history
  • Loading branch information
j-towns committed Apr 23, 2020
1 parent dae04e9 commit 322a04e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
25 changes: 25 additions & 0 deletions jax/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,31 @@ def rem(lhs, rhs):
shift_right_arithmetic = onp.right_shift
# TODO shift_right_logical

def population_count(x):
assert x.dtype in (onp.uint32, onp.uint64)
m = [
0x5555555555555555, # binary: 0101...
0x3333333333333333, # binary: 00110011..
0x0f0f0f0f0f0f0f0f, # binary: 4 zeros, 4 ones ...
0x00ff00ff00ff00ff, # binary: 8 zeros, 8 ones ...
0x0000ffff0000ffff, # binary: 16 zeros, 16 ones ...
0x00000000ffffffff, # binary: 32 zeros, 32 ones
]

if x.dtype == onp.uint32:
m = list(map(onp.uint32, m[:-1]))
else:
m = list(map(onp.uint64, m))

x = (x & m[0]) + ((x >> 1) & m[0]) # put count of each 2 bits into those 2 bits
x = (x & m[1]) + ((x >> 2) & m[1]) # put count of each 4 bits into those 4 bits
x = (x & m[2]) + ((x >> 4) & m[2]) # put count of each 8 bits into those 8 bits
x = (x & m[3]) + ((x >> 8) & m[3]) # put count of each 16 bits into those 16 bits
x = (x & m[4]) + ((x >> 16) & m[4]) # put count of each 32 bits into those 32 bits
if x.dtype == onp.uint64:
x = (x & m[5]) + ((x >> 32) & m[5]) # put count of each 64 bits into those 64 bits
return x

eq = onp.equal
ne = onp.not_equal
ge = onp.greater_equal
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
op_record("bitwise_not", 1, bool_dtypes, jtu.rand_small),
op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),
op_record("population_count", 1, bool_dtypes, jtu.rand_small),
op_record("population_count", 1, uint_dtypes, jtu.rand_small),

op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),
Expand Down

0 comments on commit 322a04e

Please sign in to comment.