Skip to content

Commit

Permalink
Switch jaxlib Python code to use the lower-level xla.ops API when building XLA ops. (#2798)
Browse files Browse the repository at this point in the history

Change in preparation for deleting xla_client.ComputationBuilder.
  • Loading branch information
hawkinsp authored Apr 22, 2020
1 parent 7334f97 commit 3f8c735
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 269 deletions.
9 changes: 7 additions & 2 deletions jaxlib/cuda_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@

_prod = lambda xs: functools.reduce(operator.mul, xs, 1)

# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
# If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
return getattr(c, "_builder", c)

def threefry2x32(c, keys, data):
"""ThreeFry2x32 kernel for GPU."""
c = _unpack_builder(c)
assert len(keys) == 2, keys
assert len(data) == 2, data
dims = c.GetShape(keys[0]).dimensions()
Expand All @@ -46,8 +51,8 @@ def threefry2x32(c, keys, data):
opaque = cuda_prng_kernels.cuda_threefry2x32_descriptor(_prod(dims))
layout = tuple(range(ndims - 1, -1, -1))
shape = xla_client.Shape.array_shape(dtype, dims, layout)
return c.CustomCallWithLayout(
b"cuda_threefry2x32",
return xla_client.ops.CustomCallWithLayout(
c, b"cuda_threefry2x32",
operands=(keys[0], keys[1], data[0], data[1]),
shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
operand_shapes_with_layout=(shape,) * 4,
Expand Down
101 changes: 57 additions & 44 deletions jaxlib/cusolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@
except ImportError:
pass


_ops = xla_client.ops
_Shape = xla_client.Shape

# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
# If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
return getattr(c, "_builder", c)

def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
Expand All @@ -59,6 +63,7 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
XLA implements unbatched triangular solve directly, so we need only implement
the batched case."""
c = _unpack_builder(c)
b_shape = c.GetShape(b)
dtype = b_shape.element_type()
dims = b_shape.dimensions()
Expand All @@ -81,8 +86,8 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
lwork, opaque = cublas_kernels.build_trsm_batched_descriptor(
np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = c.CustomCallWithLayout(
b"cublas_trsm_batched",
out = _ops.CustomCallWithLayout(
c, b"cublas_trsm_batched",
operands=(a, b),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, b_shape.dimensions(), layout),
Expand All @@ -93,11 +98,12 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
_Shape.array_shape(dtype, b_shape.dimensions(), layout),
),
opaque=opaque)
return c.GetTupleElement(out, 0)
return _ops.GetTupleElement(out, 0)


def potrf(c, a, lower):
"""Cholesky decomposition."""
c = _unpack_builder(c)
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
Expand All @@ -111,8 +117,8 @@ def potrf(c, a, lower):
np.dtype(dtype), lower, batch, n)
kernel = b"cusolver_potrf"

out = c.CustomCallWithLayout(
kernel,
out = _ops.CustomCallWithLayout(
c, kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(
Expand All @@ -126,11 +132,12 @@ def potrf(c, a, lower):
dtype, batch_dims + (n, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
return c.GetTupleElement(out, 0), c.GetTupleElement(out, 1)
return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)


def getrf(c, a):
"""LU decomposition."""
c = _unpack_builder(c)
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
Expand All @@ -151,8 +158,8 @@ def getrf(c, a):
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
kernel = b"cusolver_getrf"

out = c.CustomCallWithLayout(
kernel,
out = _ops.CustomCallWithLayout(
c, kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(
Expand All @@ -169,11 +176,12 @@ def getrf(c, a):
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
c.GetTupleElement(out, 2))
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))

def geqrf(c, a):
"""QR decomposition."""
c = _unpack_builder(c)
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
Expand All @@ -188,8 +196,8 @@ def geqrf(c, a):
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
kernel = b"cusolver_geqrf"

out = c.CustomCallWithLayout(
kernel,
out = _ops.CustomCallWithLayout(
c, kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(
Expand All @@ -206,11 +214,12 @@ def geqrf(c, a):
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
c.GetTupleElement(out, 2))
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))

def orgqr(c, a, tau):
"""Product of elementary Householder reflections."""
c = _unpack_builder(c)
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
Expand All @@ -229,8 +238,8 @@ def orgqr(c, a, tau):
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
kernel = b"cusolver_orgqr"

out = c.CustomCallWithLayout(
kernel,
out = _ops.CustomCallWithLayout(
c, kernel,
operands=(a, tau),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(
Expand All @@ -249,11 +258,12 @@ def orgqr(c, a, tau):
tuple(range(num_bd, -1, -1))),
),
opaque=opaque)
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1))
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1))


def syevd(c, a, lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
c = _unpack_builder(c)

a_shape = c.GetShape(a)
dtype = a_shape.element_type()
Expand All @@ -276,8 +286,8 @@ def syevd(c, a, lower=False):
np.dtype(dtype), lower, batch, n)
eigvals_type = _real_type(dtype)

out = c.CustomCallWithLayout(
kernel,
out = _ops.CustomCallWithLayout(
c, kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, dims, layout),
Expand All @@ -293,12 +303,13 @@ def syevd(c, a, lower=False):
_Shape.array_shape(dtype, dims, layout),
),
opaque=opaque)
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
c.GetTupleElement(out, 2))
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))


def gesvd(c, a, full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
c = _unpack_builder(c)

a_shape = c.GetShape(a)
dims = a_shape.dimensions()
Expand All @@ -316,8 +327,8 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = c.CustomCallWithLayout(
b"cusolver_gesvdj",
out = _ops.CustomCallWithLayout(
c, b"cusolver_gesvdj",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
Expand All @@ -332,21 +343,21 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
),
opaque=opaque)
s = c.GetTupleElement(out, 1)
u = c.GetTupleElement(out, 2)
v = c.GetTupleElement(out, 3)
info = c.GetTupleElement(out, 4)
vt = c.Transpose(v, tuple(range(num_bd)) + (num_bd + 1, num_bd))
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
v = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
vt = _ops.Transpose(v, tuple(range(num_bd)) + (num_bd + 1, num_bd))
if np.issubdtype(dtype, np.complexfloating):
vt = c.Conj(vt)
vt = _ops.Conj(vt)
elif m < n:
lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd + 1, num_bd) + scalar_layout
out = c.CustomCallWithLayout(
b"cusolver_gesvd",
out = _ops.CustomCallWithLayout(
c, b"cusolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
Expand All @@ -361,19 +372,19 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
),
opaque=opaque)
s = c.GetTupleElement(out, 1)
vt = c.GetTupleElement(out, 2)
u = c.GetTupleElement(out, 3)
info = c.GetTupleElement(out, 4)
s = _ops.GetTupleElement(out, 1)
vt = _ops.GetTupleElement(out, 2)
u = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
else:
lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
np.dtype(dtype), b, m, n, compute_uv, full_matrices)

scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = c.CustomCallWithLayout(
b"cusolver_gesvd",
out = _ops.CustomCallWithLayout(
c, b"cusolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
Expand All @@ -388,11 +399,13 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
),
opaque=opaque)
s = c.GetTupleElement(out, 1)
u = c.GetTupleElement(out, 2)
vt = c.GetTupleElement(out, 3)
info = c.GetTupleElement(out, 4)
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
vt = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
if not full_matrices:
u = c.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)))
vt = c.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n))
u = _ops.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)),
(1,) * len(dims))
vt = _ops.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n),
(1,) * len(dims))
return s, u, vt, info
Loading

0 comments on commit 3f8c735

Please sign in to comment.