# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections.abc import Sequence
import enum
import functools
from functools import partial
import itertools
import math
import operator
from typing import Any, Callable, TypeVar, Union, cast as type_cast, overload
import warnings
import numpy as np
import jax
from jax import tree_util
from jax.tree_util import tree_map
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import shard_alike
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_named_shape_rule,
standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import PmapSharding
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
split_list, NumpyComplexWarning)
xb = xla_bridge
xc = xla_client
xops = xla_client.ops
xe = xla_client._xla
_max = builtins.max
_min = builtins.min
_reduce = functools.reduce
T = TypeVar("T")
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def _clip_int_to_valid_range(val: int, dtype) -> int:
info = np.iinfo(dtype)
return builtins.max(info.min, builtins.min(int(val), info.max))
def _validate_shapes(shapes: Sequence[Shape]):
def _check_static_shape(shape: Shape):
checked = canonicalize_shape(shape)
if not all(idx >= 0 for idx in checked):
msg = f"Only non-negative indices are allowed when broadcasting" \
f" static shapes, but got shape {shape!r}."
raise TypeError(msg)
assert shapes
if config.dynamic_shapes.value:
# pass dynamic shapes through unchecked
return
else:
map(_check_static_shape, shapes)
def _try_broadcast_shapes(
shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
if len(shapes) == 1: return shapes[0]
ranks = {len(shape) for shape in shapes}
if len(ranks) > 1: return None # must have consistent rank
rank = ranks.pop()
if not rank: return () # scalar case
result_shape = []
for ds in unsafe_zip(*shapes):
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
# if all axes are identical objects, the resulting size is the object
result_shape.append(ds[0])
else:
# if all dims are equal (or 1), the result is the non-1 size (or 1)
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
if not non_1s:
result_shape.append(1)
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
result_shape.append(non_1s[0])
else:
return None
return tuple(result_shape)
def asarray(x: ArrayLike) -> Array:
"""Lightweight conversion of ArrayLike input to Array output."""
if isinstance(x, Array):
return x
if isinstance(x, np.ndarray) or np.isscalar(x):
# Call device_put_impl directly to avoid binding the primitive.
return dispatch._device_put_impl(x)
else:
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
@overload
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
@overload
def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]
) -> tuple[int | core.Tracer, ...]: ...
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
# NOTE: We have both cached and uncached versions to handle Tracers in shapes.
try:
return _broadcast_shapes_cached(*shapes)
except:
return _broadcast_shapes_uncached(*shapes)
@cache()
def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
return _broadcast_shapes_uncached(*shapes)
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
return result_shape
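# Illustrative sketch of the NumPy-style broadcasting implemented above
# (assumes the public ``jax.lax`` namespace):
#   >>> from jax import lax
#   >>> lax.broadcast_shapes((8, 1, 3), (5, 3))
#   (8, 5, 3)
# Shorter shapes are padded on the left with singleton dimensions, and each
# axis must either match or be 1; incompatible shapes raise ValueError.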
def _broadcast_ranks(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
assert len(s1) <= len(s2)
s1_ = s2[len(s2) - len(s1):]
if core.definitely_equal_shape(s1_, s1): return s2
else: raise ValueError
def _identity(x): return x
def _extract_tracers_dyn_shape(
shape: Sequence[int | core.Tracer]
) -> tuple[list[core.Tracer], list[int | None]]:
# Given a sequence representing a shape, pull out Tracers, replacing with None
if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information).
dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
return dyn_shape, static_shape
else:
return [], list(shape) # type: ignore
def _merge_dyn_shape(
static_shape: Sequence[int | None],
dyn_shape: Sequence[Any],
) -> tuple[int | mlir.Value | core.Tracer, ...]:
# Replace Nones in static_shape with elements of dyn_shape, in order
dyn_shape_it = iter(dyn_shape)
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
assert next(dyn_shape_it, None) is None
return shape
def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params):
source_info = source_info_util.current()
out_tracer = pe.DynamicJaxprTracer(trace, out_aval, source_info)
eqn = pe.new_jaxpr_eqn([trace.getvar(x) for x in args],
[trace.makevar(out_tracer)],
prim, params, core.no_effects, source_info)
trace.frame.add_eqn(eqn)
return out_tracer
### traceables
def neg(x: ArrayLike) -> Array:
r"""Elementwise negation: :math:`-x`."""
return neg_p.bind(x)
def sign(x: ArrayLike) -> Array:
r"""Elementwise sign.
For floating-point inputs, returns
:math:`\mathrm{sign}(x) = \begin{cases}
-1 & x < 0\\
-0 & x = -0\\
\mathit{NaN} & x = \mathit{NaN}\\
+0 & x = +0\\
1 & x > 0
\end{cases}`
For signed integer inputs, returns
:math:`\mathrm{sign}(x) = \begin{cases}
-1 & x < 0\\
0 & x = 0\\
1 & x > 0
\end{cases}`
For complex inputs, returns the complex phase, i.e.
:math:`\mathrm{sign}(x) = \frac{x}{|x|}`.
"""
return sign_p.bind(x)
def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
r"""Returns the next representable value after `x1` in the direction of `x2`.
Note that in some environments flush-denormal-to-zero semantics is used.
This means that, around zero, this function returns strictly non-zero
values which appear as zero in subsequent operations. Consider this example::
>>> jnp.nextafter(0, 1) # denormal numbers are representable
Array(1.e-45, dtype=float32, weak_type=True)
>>> jnp.nextafter(0, 1) * 1 # but are flushed to zero
Array(0., dtype=float32, weak_type=True)
For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
"""
return nextafter_p.bind(x1, x2)
def floor(x: ArrayLike) -> Array:
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
return floor_p.bind(x)
def ceil(x: ArrayLike) -> Array:
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
return ceil_p.bind(x)
class RoundingMethod(enum.IntEnum):
AWAY_FROM_ZERO = 0
TO_NEAREST_EVEN = 1
def round(x: ArrayLike,
rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
) -> Array:
r"""Elementwise round.
Rounds values to the nearest integer.
Args:
x: an array or scalar value to round.
rounding_method: the method to use when rounding halfway values
(e.g., `0.5`). See ``lax.RoundingMethod`` for the list of possible
values.
Returns:
An array containing the elementwise rounding of x.
"""
rounding_method = RoundingMethod(rounding_method)
return round_p.bind(x, rounding_method=rounding_method)
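# Illustrative sketch of the two rounding modes (assumes ``jax.numpy`` as jnp
# and the public ``jax.lax`` namespace). The choice of ``rounding_method``
# only matters for exact halfway values:
#   >>> x = jnp.array([0.5, 1.5, 2.5])
#   >>> lax.round(x)                                       # AWAY_FROM_ZERO
#   -> [1., 2., 3.]
#   >>> lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
#   -> [0., 2., 2.]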
def is_finite(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{isfinite}`.
For each element x returns `True` if and only if x is not :math:`\pm\infty` or
:math:`\mathit{NaN}`.
"""
return is_finite_p.bind(x)
def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
return exp_p.bind(x)
def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`."""
return exp2_p.bind(x)
def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
return expm1_p.bind(x)
def log(x: ArrayLike) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`."""
return log_p.bind(x)
def log1p(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`."""
return log1p_p.bind(x)
def tanh(x: ArrayLike) -> Array:
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`."""
return tanh_p.bind(x)
def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
return logistic_p.bind(x)
def sin(x: ArrayLike) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
return sin_p.bind(x)
def cos(x: ArrayLike) -> Array:
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
return cos_p.bind(x)
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arc tangent of two variables:
:math:`\mathrm{atan}({x \over y})`."""
return atan2_p.bind(x, y)
def real(x: ArrayLike) -> Array:
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.
Returns the real part of a complex number.
"""
return real_p.bind(x)
def imag(x: ArrayLike) -> Array:
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.
Returns the imaginary part of a complex number.
"""
return imag_p.bind(x)
def complex(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise make complex number: :math:`x + jy`.
Builds a complex number from real and imaginary parts.
"""
return complex_p.bind(x, y)
def conj(x: ArrayLike) -> Array:
r"""Elementwise complex conjugate function: :math:`\overline{x}`."""
# TODO(mattjj): remove input_dtype, not needed anymore
return conj_p.bind(x, input_dtype=_dtype(x))
def abs(x: ArrayLike) -> Array:
r"""Elementwise absolute value: :math:`|x|`."""
return abs_p.bind(x)
def pow(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise power: :math:`x^y`."""
return pow_p.bind(x, y)
def integer_pow(x: ArrayLike, y: int) -> Array:
r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer."""
return integer_pow_p.bind(x, y=y)
def sqrt(x: ArrayLike) -> Array:
r"""Elementwise square root: :math:`\sqrt{x}`."""
return sqrt_p.bind(x)
def rsqrt(x: ArrayLike) -> Array:
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
return rsqrt_p.bind(x)
def cbrt(x: ArrayLike) -> Array:
r"""Elementwise cube root: :math:`\sqrt[3]{x}`."""
return cbrt_p.bind(x)
def bitwise_not(x: ArrayLike) -> Array:
r"""Elementwise NOT: :math:`\neg x`."""
return not_p.bind(x)
def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise AND: :math:`x \wedge y`."""
return and_p.bind(x, y)
def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise OR: :math:`x \vee y`."""
return or_p.bind(x, y)
def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
return xor_p.bind(x, y)
def population_count(x: ArrayLike) -> Array:
r"""Elementwise popcount, count the number of set bits in each element."""
return population_count_p.bind(x)
def clz(x: ArrayLike) -> Array:
r"""Elementwise count-leading-zeros."""
return clz_p.bind(x)
def add(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise addition: :math:`x + y`."""
return add_p.bind(x, y)
def sub(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise subtraction: :math:`x - y`."""
return sub_p.bind(x, y)
def mul(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise multiplication: :math:`x \times y`."""
return mul_p.bind(x, y)
def div(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise division: :math:`x \over y`.
Integer division overflow
(division by zero or signed division of INT_SMIN with -1)
produces an implementation defined value.
"""
return div_p.bind(x, y)
def rem(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise remainder: :math:`x \bmod y`.
The sign of the result is taken from the dividend,
and the absolute value of the result is always
less than the divisor's absolute value.
Integer division overflow
(remainder by zero or remainder of INT_SMIN with -1)
produces an implementation defined value.
"""
return rem_p.bind(x, y)
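# Illustrative sketch of the truncating-remainder semantics described above
# (assumes the public ``jax.lax`` namespace):
#   >>> lax.rem(-5, 3)   # -> -2, sign follows the dividend
#   >>> lax.rem(5, -3)   # ->  2
# This matches C's ``%`` operator rather than Python's floor-modulo.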
def max(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
return max_p.bind(x, y)
def min(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
return min_p.bind(x, y)
def shift_left(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise left shift: :math:`x \ll y`."""
return shift_left_p.bind(x, y)
def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arithmetic right shift: :math:`x \gg y`."""
return shift_right_arithmetic_p.bind(x, y)
def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise logical right shift: :math:`x \gg y`."""
return shift_right_logical_p.bind(x, y)
def eq(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise equals: :math:`x = y`."""
return eq_p.bind(x, y)
def ne(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise not-equals: :math:`x \neq y`."""
return ne_p.bind(x, y)
def ge(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise greater-than-or-equals: :math:`x \geq y`."""
return ge_p.bind(x, y)
def gt(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise greater-than: :math:`x > y`."""
return gt_p.bind(x, y)
def le(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise less-than-or-equals: :math:`x \leq y`."""
return le_p.bind(x, y)
def lt(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise less-than: :math:`x < y`."""
return lt_p.bind(x, y)
def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise cast.
Wraps XLA's `ConvertElementType
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
operator, which performs an elementwise conversion from one type to another.
Similar to a C++ `static_cast`.
Args:
operand: an array or scalar value to be cast
new_dtype: a NumPy dtype representing the target type.
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
"""
return _convert_element_type(operand, new_dtype, weak_type=False)
def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
weak_type: bool = False):
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__() # type: ignore
if (dtypes.issubdtype(new_dtype, dtypes.extended) or
dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)):
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
weak_type=bool(weak_type))
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = dtypes.dtype(operand, canonicalize=False)
old_weak_type = dtypes.is_weakly_typed(operand)
if new_dtype is None:
new_dtype = old_dtype
else:
new_dtype = np.dtype(new_dtype)
new_dtype = dtypes.dtype(new_dtype, canonicalize=True)
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
# not be an error since the target dtype fits the value. Handle this case by
# converting to a NumPy array before calling bind. Without this step, we'd
# first canonicalize the input to a value of dtype int32 or int64, leading to
# an overflow error.
if type(operand) is int:
operand = np.asarray(operand).astype(new_dtype)
old_weak_type = False
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
isinstance(operand, Array) and
not (isinstance(operand, core.Tracer) and
isinstance(core.get_aval(operand), core.ConcreteArray))):
return type_cast(Array, operand)
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
weak_type=bool(weak_type))
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise bitcast.
Wraps XLA's `BitcastConvertType
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
operator, which performs a bit cast from one type to another.
The output shape depends on the size of the input and output dtypes with
the following logic::
if new_dtype.itemsize == operand.dtype.itemsize:
output_shape = operand.shape
if new_dtype.itemsize < operand.dtype.itemsize:
output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
if new_dtype.itemsize > operand.dtype.itemsize:
assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
output_shape = operand.shape[:-1]
Args:
operand: an array or scalar value to be cast
new_dtype: the new type. Should be a NumPy type.
Returns:
An array of shape `output_shape` (see above) and type `new_dtype`,
constructed from the same bits as operand.
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype)
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
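# Illustrative sketch of the shape rules documented above (assumes
# ``jax.numpy`` as jnp and the public ``jax.lax`` namespace):
#   >>> x = jnp.zeros((4,), dtype=jnp.float32)
#   >>> lax.bitcast_convert_type(x, jnp.int32).shape   # same itemsize
#   (4,)
#   >>> lax.bitcast_convert_type(x, jnp.uint8).shape   # smaller itemsize
#   (4, 4)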
def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
r"""Elementwise clamp.
Returns :math:`\mathrm{clamp}(x) = \begin{cases}
\mathit{min} & \text{if } x < \mathit{min},\\
\mathit{max} & \text{if } x > \mathit{max},\\
x & \text{otherwise}
\end{cases}`.
"""
return clamp_p.bind(min, x, max)
def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
"""Concatenates a sequence of arrays along `dimension`.
Wraps XLA's `Concatenate
<https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
operator.
Args:
operands: a sequence of arrays to concatenate. The arrays must have equal
shapes, except in the `dimension` axis.
dimension: the dimension along which to concatenate the arrays.
Returns:
An array containing the concatenation.
"""
if len(operands) == 0:
    raise ValueError("concatenate requires a non-empty sequence of arrays")
if len(operands) == 1:
op, = operands
if isinstance(op, Array):
return type_cast(Array, op)
return concatenate_p.bind(*operands, dimension=dimension)
class _enum_descriptor:
def __init__(self, val):
self.val = val
def __get__(self, _, owner):
return owner(self.val)
class Precision(xla_client.PrecisionConfig.Precision): # type: ignore
"""Precision enum for lax functions
The `precision` argument to JAX functions generally controls the tradeoff
between speed and accuracy for array computations on accelerator backends
(i.e. TPU and GPU). Members are:
DEFAULT:
Fastest mode, but least accurate. Performs computations in bfloat16.
Aliases: ``'default'``, ``'fastest'``, ``'bfloat16'``.
HIGH:
Slower but more accurate. Performs float32 computations in 3 bfloat16
passes, or using tensorfloat32 where available. Aliases: ``'high'``,
``'bfloat16_3x'``, ``'tensorfloat32'``.
HIGHEST:
Slowest but most accurate. Performs computations in float32 or float64
as applicable. Aliases: ``'highest'``, ``'float32'``.
"""
# Wrap enum values with this class.
DEFAULT = _enum_descriptor('default')
HIGH = _enum_descriptor('high')
HIGHEST = _enum_descriptor('highest')
_strings = {
'highest': xla_client.PrecisionConfig.Precision.HIGHEST,
'float32': xla_client.PrecisionConfig.Precision.HIGHEST,
'high': xla_client.PrecisionConfig.Precision.HIGH,
'bfloat16_3x': xla_client.PrecisionConfig.Precision.HIGH,
'tensorfloat32': xla_client.PrecisionConfig.Precision.HIGH,
'default': xla_client.PrecisionConfig.Precision.DEFAULT,
'bfloat16': xla_client.PrecisionConfig.Precision.DEFAULT,
'fastest': xla_client.PrecisionConfig.Precision.DEFAULT,
None: xla_client.PrecisionConfig.Precision.DEFAULT,
}
def __init__(self, arg0):
arg0 = self._strings.get(arg0, arg0)
super().__init__(arg0)
def __str__(self) -> str:
return self.name
PrecisionType = Precision
PrecisionLike = Union[
str,
PrecisionType,
tuple[str, str],
tuple[PrecisionType, PrecisionType],
None,
]
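# Illustrative sketch of how ``PrecisionLike`` values are typically passed to
# dot-like functions (assumes ``jax.numpy`` as jnp and the public ``jax.lax``
# namespace). Strings, enum members, and per-operand pairs are all accepted
# and normalized via ``canonicalize_precision``:
#   >>> x = jnp.ones((2, 2)); y = jnp.ones((2, 2))
#   >>> lax.dot(x, y, precision=lax.Precision.HIGHEST)
#   >>> lax.dot(x, y, precision='bfloat16_3x')
#   >>> lax.dot(x, y, precision=(lax.Precision.HIGH, lax.Precision.DEFAULT))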
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
Wraps XLA's `Dot
<https://www.tensorflow.org/xla/operation_semantics#dot>`_
operator.
For more general contraction, see the `dot_general` operator.
Args:
lhs: an array of dimension 1 or 2.
rhs: an array of dimension 1 or 2.
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
An array containing the product.
"""
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision,
preferred_element_type=preferred_element_type)
else:
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
lhs.shape, rhs.shape))
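# Illustrative shape behavior for ``dot`` (assumes ``jax.numpy`` as jnp and
# the public ``jax.lax`` namespace):
#   >>> A = jnp.ones((3, 4)); B = jnp.ones((4, 5)); v = jnp.ones((4,))
#   >>> lax.dot(A, B).shape   # matrix/matrix
#   (3, 5)
#   >>> lax.dot(A, v).shape   # matrix/vector
#   (3,)
#   >>> lax.dot(v, v).shape   # vector/vector (inner product)
#   ()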
DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
tuple[Sequence[int], Sequence[int]]]
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""General dot product/contraction operator.
Wraps XLA's `DotGeneral
<https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
operator.
The semantics of ``dot_general`` are complicated, but most users should not have to
use it directly. Instead, you can use higher-level functions like :func:`jax.numpy.dot`,
:func:`jax.numpy.matmul`, :func:`jax.numpy.tensordot`, :func:`jax.numpy.einsum`,
and others which will construct appropriate calls to ``dot_general`` under the hood.
If you really want to understand ``dot_general`` itself, we recommend reading XLA's
`DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
operator documentation.
Args:
lhs: an array
rhs: an array
dimension_numbers: a tuple of tuples of sequences of ints of the form
``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))``
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
An array whose first dimensions are the (shared) batch dimensions, followed by
the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
non-contracting/non-batch dimensions.
"""
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
api_util._ensure_index_tuple(rhs_contract))
bdims = (api_util._ensure_index_tuple(lhs_batch),
api_util._ensure_index_tuple(rhs_batch))
preferred_element_type = (
None if preferred_element_type is None else
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(cdims, bdims),
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type)
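# Illustrative sketch: a batched matmul expressed via ``dimension_numbers``
# (assumes ``jax.numpy`` as jnp and the public ``jax.lax`` namespace).
# Contract lhs axis 2 with rhs axis 1, and batch over axis 0 of both:
#   >>> lhs = jnp.ones((8, 3, 4))   # (batch, m, k)
#   >>> rhs = jnp.ones((8, 4, 5))   # (batch, k, n)
#   >>> dnums = (((2,), (1,)), ((0,), (0,)))
#   >>> lax.dot_general(lhs, rhs, dnums).shape
#   (8, 3, 5)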
def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
"""Broadcasts an array, adding new leading dimensions
Args:
operand: an array
sizes: a sequence of integers, giving the sizes of new leading dimensions
to add to the front of the array.
Returns:
An array containing the result.
See Also:
jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
"""
dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
def broadcast_in_dim(operand: ArrayLike, shape: Shape,
broadcast_dimensions: Sequence[int]) -> Array:
"""Wraps XLA's `BroadcastInDim
<https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
operator.
Args:
operand: an array
shape: the shape of the target array
broadcast_dimensions: which dimension in the target shape each dimension
of the operand shape corresponds to. That is, dimension i of the operand
becomes dimension broadcast_dimensions[i] of the result.
Returns:
An array containing the result.
See Also:
jax.lax.broadcast : simpler interface to add new leading dimensions.
"""
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array):
return type_cast(Array, operand)
if config.dynamic_shapes.value:
# We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information).
dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
else:
dyn_shape, static_shape = [], shape # type: ignore
return broadcast_in_dim_p.bind(
operand, *dyn_shape, shape=tuple(static_shape),
broadcast_dimensions=tuple(broadcast_dimensions))
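# Illustrative sketch of ``broadcast_dimensions`` (assumes ``jax.numpy`` as
# jnp and the public ``jax.lax`` namespace). Dimension i of the operand maps
# to dimension broadcast_dimensions[i] of the result:
#   >>> x = jnp.arange(3)                              # shape (3,)
#   >>> lax.broadcast_in_dim(x, (2, 3), (1,)).shape    # rows are copies of x
#   (2, 3)
#   >>> lax.broadcast_in_dim(x, (3, 2), (0,)).shape    # columns are copies of x
#   (3, 2)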
def broadcast_to_rank(x: Array, rank: int) -> Array:
"""Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
return broadcast(x, (1,) * (rank - x.ndim))
def reshape(operand: ArrayLike, new_sizes: Shape,
dimensions: Sequence[int] | None = None) -> Array:
"""Wraps XLA's `Reshape
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
``lax.expand_dims``. These preserve information about axis identity that may
be useful for advanced transformation rules.
Args:
operand: array to be reshaped.
new_sizes: sequence of integers specifying the resulting shape. The size
of the final array must match the size of the input.
dimensions: optional sequence of integers specifying the permutation order of
the input shape. If specified, the length must match ``operand.shape``.
Returns:
out: reshaped array.
Examples:
Simple reshaping from one to two dimensions:
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
[3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)
Reshaping to one dimension with permutation of dimensions:
>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)
"""
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
same_shape = core.definitely_equal_shape(np.shape(operand), new_sizes)
if dimensions is None:
same_dims = True
dims = None
else:
dims = api_util._ensure_index_tuple(dimensions)
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
if np.shape(operand) and same_shape and same_dims and isinstance(operand, Array):
return type_cast(Array, operand)
else:
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
return reshape_p.bind(
operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
dimensions=None if dims is None or same_dims else dims)
def pad(operand: ArrayLike, padding_value: ArrayLike,
padding_config: Sequence[tuple[int, int, int]]) -> Array:
"""Applies low, high, and/or interior padding to an array.
Wraps XLA's `Pad
<https://www.tensorflow.org/xla/operation_semantics#pad>`_
operator.
Args:
operand: an array to be padded.
padding_value: the value to be inserted as padding. Must have the same dtype
as ``operand``.
padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
giving the amount of low, high, and interior (dilation) padding to insert
in each dimension.
Returns:
The ``operand`` array with padding value ``padding_value`` inserted in each
dimension according to the ``padding_config``.
"""
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
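# Illustrative sketch of ``padding_config`` (assumes ``jax.numpy`` as jnp and
# the public ``jax.lax`` namespace). Each (low, high, interior) triple pads
# before, after, and between the existing elements of that dimension:
#   >>> lax.pad(jnp.array([1., 2., 3.]), 0., [(1, 2, 1)])
#   -> [0., 1., 0., 2., 0., 3., 0., 0.]
# Negative ``low``/``high`` values remove elements from the corresponding edge.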
def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
"""Wraps XLA's `Rev
<https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
operator.
"""
return rev_p.bind(operand, dimensions=tuple(dimensions))
def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
"""Selects between two branches based on a boolean predicate.
Wraps XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator.
In general :func:`~jax.lax.select` leads to evaluation of both branches, although
the compiler may elide computations if possible. For a similar function that
usually evaluates only a single branch, see :func:`~jax.lax.cond`.
Args:
pred: boolean array
on_true: array containing entries to return where ``pred`` is True. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_false``.
on_false: array containing entries to return where ``pred`` is False. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_true``.
Returns:
result: array with same shape and dtype as ``on_true`` and ``on_false``.
"""
# Caution! The select_n_p primitive has the *opposite* order of arguments to
# select(). This is because it implements `select_n`.
return select_n_p.bind(pred, on_false, on_true)
def select_n(which: ArrayLike, *cases: ArrayLike) -> Array:
"""Selects array values from multiple cases.
Generalizes XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator. Unlike XLA's version, the operator is variadic and can select
from many cases using an integer `pred`.
Args:
which: determines which case should be returned. Must be an array containing
either boolean or integer values. May either be a scalar or have
shape matching ``cases``. For each array element, the value of ``which``
determines which of ``cases`` is taken. ``which`` must be in the range
``[0 .. len(cases))``; for values outside that range the behavior is
implementation-defined.
*cases: a non-empty list of array cases. All must have equal dtypes and
equal shapes.
Returns:
An array with shape and dtype equal to the cases, whose values are chosen
according to ``which``.
"""
if len(cases) == 0:
raise ValueError("select_n() must have at least one case")
return select_n_p.bind(which, *cases)
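# Illustrative sketch of integer-indexed, elementwise selection (assumes
# ``jax.numpy`` as jnp and the public ``jax.lax`` namespace):
#   >>> which = jnp.array([0, 1, 2])
#   >>> lax.select_n(which,
#   ...              jnp.array([10, 11, 12]),   # case 0
#   ...              jnp.array([20, 21, 22]),   # case 1
#   ...              jnp.array([30, 31, 32]))   # case 2
#   -> [10, 21, 32]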
def transpose(operand: ArrayLike,
permutation: Sequence[int] | np.ndarray) -> Array:
"""Wraps XLA's `Transpose
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
operator.
"""
permutation = tuple(operator.index(d) for d in permutation)
if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array):
return type_cast(Array, operand)
else:
return transpose_p.bind(operand, permutation=permutation)
def argmin(operand: ArrayLike, axis: int,
index_dtype: DTypeLike) -> Array:
"""Computes the index of the minimum element along ``axis``."""
return argmin_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))
def argmax(operand: ArrayLike, axis: int,
index_dtype: DTypeLike) -> Array:
"""Computes the index of the maximum element along ``axis``."""
return argmax_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))
def reduce(operands: Any,
init_values: Any,
computation: Callable[[Any, Any], Any],
dimensions: Sequence[int]) -> Any:
"""Wraps XLA's `Reduce
<https://www.tensorflow.org/xla/operation_semantics#reduce>`_
operator.
``init_values`` and ``computation`` together must form a `monoid
<https://en.wikipedia.org/wiki/Monoid>`_
for correctness. That is ``init_values`` must be an identity of
``computation``, and ``computation`` must be associative. XLA may exploit both
of these properties during code generation; if either is violated the result
is undefined.
"""
flat_operands, operand_tree = tree_util.tree_flatten(operands)
flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
if operand_tree != init_value_tree:
raise ValueError('Operands must have the same tree structure as init_values:'