Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add population_count primitive to lax #2753

Merged
merged 5 commits into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.lax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Operators
bitwise_and
bitwise_or
bitwise_xor
population_count
broadcast
broadcasted_iota
broadcast_in_dim
Expand Down
6 changes: 6 additions & 0 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ def bitwise_xor(x: Array, y: Array) -> Array:
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
return xor_p.bind(x, y)

def population_count(x: Array) -> Array:
r"""Elementwise popcount, count the number of set bits in each element."""
return population_count_p.bind(x)

def add(x: Array, y: Array) -> Array:
r"""Elementwise addition: :math:`x + y`."""
return add_p.bind(x, y)
Expand Down Expand Up @@ -2023,6 +2027,8 @@ def _pow_jvp_rhs(g, ans, x, y):
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)

population_count_p = standard_unop(_bool_or_int, 'population_count')

def _add_transpose(t, x, y):
# The following linearity assertion is morally true, but because in some cases we
# instantiate zeros for convenience, it doesn't always hold.
Expand Down
25 changes: 25 additions & 0 deletions jax/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,31 @@ def rem(lhs, rhs):
shift_right_arithmetic = onp.right_shift
# TODO shift_right_logical

def population_count(x):
assert x.dtype in (onp.uint32, onp.uint64)
m = [
0x5555555555555555, # binary: 0101...
0x3333333333333333, # binary: 00110011..
0x0f0f0f0f0f0f0f0f, # binary: 4 zeros, 4 ones ...
0x00ff00ff00ff00ff, # binary: 8 zeros, 8 ones ...
0x0000ffff0000ffff, # binary: 16 zeros, 16 ones ...
0x00000000ffffffff, # binary: 32 zeros, 32 ones
]

if x.dtype == onp.uint32:
m = list(map(onp.uint32, m[:-1]))
else:
m = list(map(onp.uint64, m))

x = (x & m[0]) + ((x >> 1) & m[0]) # put count of each 2 bits into those 2 bits
x = (x & m[1]) + ((x >> 2) & m[1]) # put count of each 4 bits into those 4 bits
x = (x & m[2]) + ((x >> 4) & m[2]) # put count of each 8 bits into those 8 bits
x = (x & m[3]) + ((x >> 8) & m[3]) # put count of each 16 bits into those 16 bits
x = (x & m[4]) + ((x >> 16) & m[4]) # put count of each 32 bits into those 32 bits
if x.dtype == onp.uint64:
x = (x & m[5]) + ((x >> 32) & m[5]) # put count of each 64 bits into those 64 bits
return x

eq = onp.equal
ne = onp.not_equal
ge = onp.greater_equal
Expand Down
1 change: 1 addition & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
op_record("bitwise_not", 1, bool_dtypes, jtu.rand_small),
op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),
op_record("population_count", 1, uint_dtypes, partial(jtu.rand_int, 1 << 32)),

op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),
Expand Down