forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgpu_prng.py
87 lines (70 loc) · 2.68 KB
/
gpu_prng.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from functools import partial
import itertools
import operator
from typing import Optional, Union
import jaxlib.mlir.ir as ir
from jaxlib import xla_client
from .hlo_helpers import custom_call
try:
from .cuda import _prng as _cuda_prng
for _name, _value in _cuda_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_prng = None
try:
from .rocm import _prng as _hip_prng
for _name, _value in _hip_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
_hip_prng = None
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _threefry2x32_lowering(prng, platform, keys, data,
length: Optional[Union[int, ir.Value]] = None):
"""ThreeFry2x32 kernel for GPU."""
assert len(keys) == 2, keys
assert len(data) == 2, data
assert (ir.RankedTensorType(keys[0].type).element_type ==
ir.IntegerType.get_unsigned(32)), keys[0].type
typ = keys[0].type
dims = ir.RankedTensorType(typ).shape
for x in itertools.chain(keys, data):
assert x.type == typ, (x.type, typ)
ndims = len(dims)
layout = tuple(range(ndims - 1, -1, -1))
operand_layouts = [layout] * 4
operands = [keys[0], keys[1], data[0], data[1]]
if length is None:
length = _prod(dims)
if isinstance(length, int):
opaque = prng.threefry2x32_descriptor(length)
else:
opaque = prng.threefry2x32_descriptor(-1)
assert (ir.RankedTensorType(length.type).element_type ==
ir.IntegerType.get_signless(64)), length
assert (ir.RankedTensorType(length.type).shape ==
[1]), (length, ir.RankedTensorType(length.type).shape)
operands.append(length)
operand_layouts.append((0,))
return custom_call(
f"{platform}_threefry2x32",
[typ, typ],
operands,
backend_config=opaque,
operand_layouts=operand_layouts,
result_layouts=[layout] * 2)
cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu")
rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip")