Minor update to the new MLIR specification for custom ops #8

Merged · 5 commits · Jan 2, 2023
8 changes: 4 additions & 4 deletions .github/action/Dockerfile
@@ -1,10 +1,10 @@
FROM nvidia/cuda:10.2-devel-ubuntu18.04
FROM nvidia/cuda:11.8.0-devel-ubuntu20.04

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip
DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake

RUN python3 -m pip install -U pip && \
python3 -m pip install -U jax jaxlib==0.1.57+cuda102 -f https://storage.googleapis.com/jax-releases/jax_releases.html
RUN pip install --upgrade pip && \
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

COPY entrypoint.sh /entrypoint.sh

133 changes: 70 additions & 63 deletions README.md
@@ -361,83 +361,81 @@ for _name, _value in cpu_ops.registrations().items():
xla_client.register_cpu_custom_call_target(_name, _value)
```
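
As a quick aside, here is a hedged sketch of what that registration data looks like; the import path and the exact key names are assumptions based on `lib/cpu_ops.cc`, not something guaranteed by this snippet:

```python
# Inspect the PyCapsules exported by the compiled CPU extension; each key is
# the op name that the lowering rule below must reference when it emits a
# custom call, and each value wraps a pointer to the corresponding C++ kernel.
from kepler_jax import cpu_ops

print(list(cpu_ops.registrations().keys()))
# expected to look something like: ['cpu_kepler_f32', 'cpu_kepler_f64']
```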

Then, the **translation rule** is defined roughly as follows (the one you'll
Then, the **lowering rule** is defined roughly as follows (the one you'll
find in the source code is a little more complicated since it supports both CPU
and GPU translation):

```python
# src/kepler_jax/kepler_jax.py
import numpy as np
from jax.interpreters import mlir
from jaxlib.mhlo_helpers import custom_call

def _kepler_cpu_translation(c, mean_anom, ecc):
# The inputs have "shapes" that provide both the shape and the dtype
mean_anom_shape = c.get_shape(mean_anom)
ecc_shape = c.get_shape(ecc)
def _kepler_lowering(ctx, mean_anom, ecc):

# Extract the dtype and shape
dtype = mean_anom_shape.element_type()
dims = mean_anom_shape.dimensions()
assert ecc_shape.element_type() == dtype
assert ecc_shape.dimensions() == dims
# Checking that input types and shape agree
assert mean_anom.type == ecc.type

# Extract the numpy type of the inputs
mean_anom_aval, ecc_aval = ctx.avals_in
np_dtype = np.dtype(mean_anom_aval.dtype)

# The inputs and outputs all have the same shape and memory layout
# so let's predefine this specification
dtype = mlir.ir.RankedTensorType(mean_anom.type)
dims = dtype.shape
layout = tuple(range(len(dims) - 1, -1, -1))

# The total size of the input is the product across dimensions
size = np.prod(dims).astype(np.int64)

# The inputs and outputs all have the same shape so let's predefine this
# specification
shape = xla_client.Shape.array_shape(
np.dtype(dtype), dims, tuple(range(len(dims) - 1, -1, -1))
)

# We dispatch a different call depending on the dtype
if dtype == np.float32:
op_name = b"cpu_kepler_f32"
elif dtype == np.float64:
op_name = b"cpu_kepler_f64"
if np_dtype == np.float32:
op_name = "cpu_kepler_f32"
elif np_dtype == np.float64:
op_name = "cpu_kepler_f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
raise NotImplementedError(f"Unsupported dtype {np_dtype}")

# On the CPU, we pass the size of the data as the first input
# argument
return xla_client.ops.CustomCallWithLayout(
c,
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
# The inputs:
operands=(xla_client.ops.ConstantLiteral(c, size), mean_anom, ecc),
# The input shapes:
operand_shapes_with_layout=(
xla_client.Shape.array_shape(np.dtype(np.int64), (), ()),
shape,
shape,
),
# The output shapes:
shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
)

xla.backend_specific_translations["cpu"][_kepler_prim] = _kepler_cpu_translation
operands=[mlir.ir_constant(size), mean_anom, ecc],
# Layout specification:
operand_layouts=[(), layout, layout],
result_layouts=[layout, layout]
)

mlir.register_lowering(
_kepler_prim,
_kepler_lowering,
platform="cpu")
```

There appears to be a lot going on here, but most of it is just typechecking.
The main meat of it is the `CustomCallWithLayout` function which, as far as I
can tell, isn't documented anywhere. Here's a summary of its arguments, as best
as I can tell:

- The first argument is the XLA builder that you were passed when your
translation rule was called.
There appears to be a lot going on here, but most of it is just type checking.
The main meat of it is the `custom_call` function, which is a thin convenience
wrapper around the `mhlo.CustomCallOp` (documented
[here](https://www.tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop)).
Here's a summary of its arguments:

- The second argument is the name (as `bytes`!) that you gave your `PyCapsule`
- The first argument is the name that you gave your `PyCapsule`
in the `registrations` dictionary in `lib/cpu_ops.cc`. You can check what
names your capsules had by looking at `cpu_ops.registrations().keys()`.

- Then, the following arguments give the input arguments, and the "shapes" of
the input and output arrays. In this context, a "shape" is specified by a data
type, a tuple defining the size of each dimension (what I would normally call
the shape), and a tuple defining the dimension order. In this case, we're
requiring that all of our inputs and outputs are of the same "shape".
- Then, the two following arguments give the "types" of the outputs and
specify the input arguments (operands). In this context, a "type" combines
a tuple giving the size of each dimension (what I would normally call the
shape) with the element data type of the array (e.g. float32). In this
case, both of our outputs have the same type/shape.

- Finally, with the last two arguments, we specify the memory layout
of both input and output buffers.

It's worth remembering that we're expecting the first argument to our function
to be the size of the arrays, and you'll see that that is included as a
`ConstantLiteral` parameter (explicitly cast to `int64`).
`mlir.ir_constant` parameter.
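
To make the layout and size arguments above a little more concrete, here is a small self-contained sketch (plain NumPy, not part of the package; the `(3, 4)` shape is just a hypothetical example):

```python
import numpy as np

# A hypothetical rank-2 input, e.g. a (3, 4) array of mean anomalies
dims = (3, 4)

# Row-major ("C order") layout: dimensions listed from minor to major.
# For rank 2 this is (1, 0), and it's what operand_layouts/result_layouts
# expect for each array argument; the scalar size operand gets ().
layout = tuple(range(len(dims) - 1, -1, -1))

# Total number of elements, passed to the CPU kernel as its first operand
# (via mlir.ir_constant in the lowering rule above)
size = np.prod(dims).astype(np.int64)

print(layout, size)  # (1, 0) 12
```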

I'm not going to talk about the **JVP rule** here since it's quite problem
specific, but I've tried to comment the code reasonably thoroughly so check out
@@ -631,15 +629,17 @@ from kepler_jax import gpu_ops
for _name, _value in gpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")

def _kepler_gpu_translation(c, mean_anom, ecc):
def _kepler_lowering_gpu(ctx, mean_anom, ecc):
# Most of this function is the same as the CPU version above...

# ...

# The name of the op is now prefaced with 'gpu' (our choice, see lib/gpu_ops.cc,
# not a requirement)
if dtype == np.float32:
op_name = b"gpu_kepler_f32"
elif dtype == np.float64:
op_name = b"gpu_kepler_f64"
if np_dtype == np.float32:
op_name = "gpu_kepler_f32"
elif np_dtype == np.float64:
op_name = "gpu_kepler_f64"
else:
raise NotImplementedError(f"Unsupported dtype {np_dtype}")

@@ -648,16 +648,23 @@ def _kepler_gpu_translation(c, mean_anom, ecc):

# The syntax is *almost* the same as the CPU version, but we need to pass the
# size using 'opaque' rather than as an input
return xla_client.ops.CustomCallWithLayout(
c,
return custom_call(
op_name,
operands=(mean_anom, ecc),
operand_shapes_with_layout=(shape, shape),
shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
opaque=opaque,
# Output types
out_types=[dtype, dtype],
# The inputs:
operands=[mean_anom, ecc],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout, layout],
# GPU-specific additional data for the kernel
backend_config=opaque
)

xla.backend_specific_translations["gpu"][_kepler_prim] = _kepler_gpu_translation
mlir.register_lowering(
_kepler_prim,
_kepler_lowering_gpu,
platform="gpu")
```

Otherwise, nothing else from our CPU implementation needs to change.
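
One way to sanity-check that a lowering rule like this is actually being picked up is to dump the lowered module and look for the custom call. This is a hedged sketch: the `jax.jit(...).lower(...).as_text()` inspection path and the exact op name in the output depend on your JAX version and default backend, so treat it as illustrative:

```python
import jax
import jax.numpy as jnp

from kepler_jax import kepler

mean_anom = jnp.linspace(0.0, 2 * jnp.pi, 10)
ecc = jnp.full_like(mean_anom, 0.3)

# Lower (but don't compile) the jitted function and print the MLIR module;
# the text should contain a custom_call targeting something like
# "cpu_kepler_f32" or "gpu_kepler_f32", depending on the backend in use.
lowered = jax.jit(kepler).lower(mean_anom, ecc)
print(lowered.as_text())
```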
103 changes: 50 additions & 53 deletions src/kepler_jax/kepler_jax.py
@@ -5,11 +5,12 @@
from functools import partial

import numpy as np
from jax import numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes, lax
from jax.interpreters import ad, batching, xla
from jax import numpy as jnp
from jax.abstract_arrays import ShapedArray
from jax.interpreters import ad, batching, mlir, xla
from jax.lib import xla_client
from jaxlib.mhlo_helpers import custom_call

# Register the CPU XLA custom calls
from . import cpu_ops
@@ -26,11 +27,10 @@
for _name, _value in gpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")

xops = xla_client.ops


# This function exposes the primitive to user code and this is the only
# public-facing function in this module


def kepler(mean_anom, ecc):
# We're going to apply array broadcasting here since the logic of our op
# is much simpler if we require the inputs to all have the same shapes
@@ -56,73 +56,71 @@ def _kepler_abstract(mean_anom, ecc):
return (ShapedArray(shape, dtype), ShapedArray(shape, dtype))


# We also need a translation rule to convert the function into an XLA op. In
# our case this is the custom XLA op that we've written. We're wrapping two
# translation rules into one here: one for the CPU and one for the GPU
def _kepler_translation(c, mean_anom, ecc, *, platform="cpu"):
# The inputs have "shapes" that provide both the shape and the dtype
mean_anom_shape = c.get_shape(mean_anom)
ecc_shape = c.get_shape(ecc)
# We also need a lowering rule to provide an MLIR "lowering" of our primitive.
# This provides a mechanism for exposing our custom C++ and/or CUDA interfaces
# to the JAX XLA backend. We're wrapping two lowering rules into one here:
# one for the CPU and one for the GPU
def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"):

# Checking that input types and shape agree
assert mean_anom.type == ecc.type

# Extract the dtype and shape
dtype = mean_anom_shape.element_type()
dims = mean_anom_shape.dimensions()
assert ecc_shape.element_type() == dtype
assert ecc_shape.dimensions() == dims
# Extract the numpy type of the inputs
mean_anom_aval, _ = ctx.avals_in
np_dtype = np.dtype(mean_anom_aval.dtype)

# The inputs and outputs all have the same shape and memory layout
# so let's predefine this specification
dtype = mlir.ir.RankedTensorType(mean_anom.type)
dims = dtype.shape
layout = tuple(range(len(dims) - 1, -1, -1))

# The total size of the input is the product across dimensions
size = np.prod(dims).astype(np.int64)

# The inputs and outputs all have the same shape so let's predefine this
# specification
shape = xla_client.Shape.array_shape(
np.dtype(dtype), dims, tuple(range(len(dims) - 1, -1, -1))
)

# We dispatch a different call depending on the dtype
if dtype == np.float32:
op_name = platform.encode() + b"_kepler_f32"
elif dtype == np.float64:
op_name = platform.encode() + b"_kepler_f64"
if np_dtype == np.float32:
op_name = platform + "_kepler_f32"
elif np_dtype == np.float64:
op_name = platform + "_kepler_f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
raise NotImplementedError(f"Unsupported dtype {np_dtype}")

# And then the following is what changes between the GPU and CPU
if platform == "cpu":
# On the CPU, we pass the size of the data as the first input
# argument
return xops.CustomCallWithLayout(
c,
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
# The inputs:
operands=(xops.ConstantLiteral(c, size), mean_anom, ecc),
# The input shapes:
operand_shapes_with_layout=(
xla_client.Shape.array_shape(np.dtype(np.int64), (), ()),
shape,
shape,
),
# The output shapes:
shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
operands=[mlir.ir_constant(size), mean_anom, ecc],
# Layout specification:
operand_layouts=[(), layout, layout],
result_layouts=[layout, layout]
)

elif platform == "gpu":
if gpu_ops is None:
raise ValueError(
"The 'kepler_jax' module was not compiled with CUDA support"
)

# On the GPU, we do things a little differently and encapsulate the
# dimension using the 'opaque' parameter
opaque = gpu_ops.build_kepler_descriptor(size)

return xops.CustomCallWithLayout(
c,
return custom_call(
op_name,
operands=(mean_anom, ecc),
operand_shapes_with_layout=(shape, shape),
shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
opaque=opaque,
# Output types
out_types=[dtype, dtype],
# The inputs:
operands=[mean_anom, ecc],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout, layout],
# GPU-specific additional data
backend_config=opaque
)

raise ValueError(
@@ -188,12 +186,11 @@ def _kepler_batch(args, axes):
_kepler_prim.def_abstract_eval(_kepler_abstract)

# Connect the MLIR lowering rules for JIT compilation
xla.backend_specific_translations["cpu"][_kepler_prim] = partial(
_kepler_translation, platform="cpu"
)
xla.backend_specific_translations["gpu"][_kepler_prim] = partial(
_kepler_translation, platform="gpu"
)
for platform in ["cpu", "gpu"]:
mlir.register_lowering(
_kepler_prim,
partial(_kepler_lowering, platform=platform),
platform=platform)

# Connect the JVP and batching rules
ad.primitive_jvps[_kepler_prim] = _kepler_jvp