-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
core.py
2965 lines (2458 loc) · 102 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator,
Sequence, MutableSet, MutableMapping)
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
import functools
from functools import partial, total_ordering
import gc
import inspect
import itertools as it
import math
import operator
import threading
import types
from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar,
overload, Union)
import warnings
from weakref import ref
import numpy as np
from jax._src import dtypes
from jax._src import config
from jax._src import effects
from jax._src import compute_on
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete,
HashableFunction, HashableWrapper, weakref_lru_cache,
partition_list, StrictABCMeta)
import jax._src.pretty_printer as pp
from jax._src.lib import jax_jit
from jax._src import traceback_util
from jax._src.typing import Array, DimSize, Shape
from jax._src import typing
from jax._src import xla_metadata as xla_metadata_lib
# Register this file with traceback_util's exclusion list (presumably so that
# frames from this module are filtered out of user-facing tracebacks).
traceback_util.register_exclusion(__file__)

# From here on, `zip` and `map` in this module refer to `safe_zip`/`safe_map`
# from jax._src.util; the builtins remain reachable as `unsafe_zip` and
# `unsafe_map`.
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map

# Integer flag read by `escaped_tracer_error` to decide how many stack frames
# to include in its message. The default comes from the
# JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES environment variable (5 is passed as
# the fallback value).
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag(
    'jax_tracer_error_num_traceback_frames',
    config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
    help='Set the number of stack frames in JAX tracer error messages.'
)
# -------------------- jaxprs --------------------

# Module-level aliases for names defined in jax._src.effects, so the rest of
# this file can refer to them unqualified.
Effect = effects.Effect
Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
class JaxprDebugInfo(NamedTuple):
  """Optional debugging metadata that can be attached to a `Jaxpr`.

  `Jaxpr.__init__` asserts that `arg_names` has one entry per input variable
  and `result_paths` has one entry per output variable.
  """
  traced_for: str     # e.g. 'jit', 'scan', etc
  func_src_info: str | None  # e.g. f'{fun.__name__} at {filename}:{lineno}'
  arg_names: tuple[str | None, ...]  # e.g. ('args[0]', ... )
  result_paths: tuple[str, ...]  # e.g. ('[0]', '[1]', ...)
class Jaxpr:
  """A jaxpr: constant variables, input variables, equations and outputs.

  The constructor copies every sequence argument into a fresh list. The stored
  values are exposed through read-only properties of the same names.
  """
  __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
               '_effects', '_debug_info']

  _constvars: list[Var]
  _invars: list[Var]
  _outvars: list[Atom]
  _eqns: list[JaxprEqn]
  _effects: Effects
  _debug_info: JaxprDebugInfo | None

  def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
               outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
               effects: Effects = no_effects,
               debug_info: JaxprDebugInfo | None = None):
    """
    Args:
      constvars: variables introduced for constants. Array constants are
        replaced with such variables while scalar constants are kept inline.
      invars: input variables. `constvars` and `invars` together form the
        inputs to the Jaxpr.
      outvars: output atoms.
      eqns: equations.
      effects: set of effects; a superset of the union of the effects of the
        individual equations.
      debug_info: optional JaxprDebugInfo. When given, its `arg_names` and
        `result_paths` must match `invars` and `outvars` in length.
    """
    self._constvars = list(constvars)
    self._invars = list(invars)
    self._outvars = list(outvars)
    self._eqns = list(eqns)
    self._effects = effects
    self._debug_info = debug_info
    if debug_info:
      assert (len(debug_info.arg_names) == len(invars) and
              len(debug_info.result_paths) == len(outvars))

  @property
  def constvars(self) -> list[Var]:
    return self._constvars

  @property
  def invars(self) -> list[Var]:
    return self._invars

  @property
  def outvars(self) -> list[Atom]:
    return self._outvars

  @property
  def eqns(self) -> list[JaxprEqn]:
    return self._eqns

  @property
  def effects(self) -> Effects:
    return self._effects

  @property
  def debug_info(self) -> JaxprDebugInfo | None:
    return self._debug_info

  def __str__(self):
    return str(self.pretty_print())

  __repr__ = __str__

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   custom_pp_eqn_rules=True, name_stack=False,
                   print_effects: bool = False, **kwargs):
    """Render this jaxpr; extra `kwargs` are passed on to `doc.format`."""
    settings = dict(
        source_info=source_info,
        print_shapes=print_shapes,
        custom_pp_eqn_rules=custom_pp_eqn_rules,
        name_stack=name_stack,
        print_effects=print_effects,
    )
    doc = pp_toplevel_jaxpr(self, **settings)
    return doc.format(**kwargs)

  def _repr_pretty_(self, p, cycle):
    rendered = self.pretty_print(use_color=True)
    return p.text(rendered)

  def replace(self, **kwargs):
    """Return a new Jaxpr with the named fields overridden.

    Raises:
      ValueError: if `kwargs` contains names other than the constructor's
        parameters.
    """
    field_names = ("constvars", "invars", "outvars", "eqns", "effects",
                   "debug_info")
    # Pop the recognised names so that anything left over is unknown.
    fields = {name: kwargs.pop(name, getattr(self, name))
              for name in field_names}
    replaced = Jaxpr(**fields)
    if kwargs:
      raise ValueError(f"Unknown keyword arguments: {kwargs}")
    return replaced
def join_effects(*effects: Effects) -> Effects:
  """Return the union of the given effect sets.

  With no arguments the module-level `no_effects` value is returned.
  """
  if not effects:
    return no_effects
  return set().union(*effects)
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every Jaxpr found among the values of `params`.

  Each value is inspected directly, or element by element when it is a tuple.
  A `ClosedJaxpr` contributes its underlying `.jaxpr`.
  """
  for param in params.values():
    if not isinstance(param, tuple):
      param = (param,)
    for candidate in param:
      if isinstance(candidate, Jaxpr):
        yield candidate
      elif isinstance(candidate, ClosedJaxpr):
        yield candidate.jaxpr
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Yield the jaxprs that appear in the params of each equation of `jaxpr`.

  Only one level is visited: jaxprs nested inside the yielded jaxprs are not
  explored.
  """
  for equation in jaxpr.eqns:
    for sub in jaxprs_in_params(equation.params):
      yield sub
class ClosedJaxpr:
  """A `Jaxpr` paired with concrete values for its `constvars`."""
  __slots__ = ['__weakref__', '_jaxpr', '_consts']

  _jaxpr: Jaxpr
  _consts: list[Any]

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    """Store `jaxpr` and a list copy of `consts` (one per constvar)."""
    assert len(consts) == len(jaxpr.constvars)
    # assert not any(isinstance(c, Tracer) for c in consts)  # TODO(mattjj): enable
    self._jaxpr = jaxpr
    self._consts = list(consts)

  @property
  def jaxpr(self):
    return self._jaxpr

  @property
  def consts(self):
    return self._consts

  @property
  def in_avals(self):
    return [var.aval for var in self.jaxpr.invars]

  @property
  def out_avals(self):
    return [atom.aval for atom in self.jaxpr.outvars]

  @property
  def literals(self):
    return self.consts  # backwards compatible alias

  @property
  def eqns(self):
    return self.jaxpr.eqns

  @property
  def effects(self) -> Effects:
    return self.jaxpr.effects

  def map_jaxpr(self, f):
    """Return a ClosedJaxpr wrapping `f(self.jaxpr)` with the same consts."""
    transformed = f(self.jaxpr)
    return ClosedJaxpr(transformed, self.consts)

  def replace(self, *, jaxpr=None, consts=None):
    """Return a copy with `jaxpr` and/or `consts` replaced when not None."""
    if jaxpr is None:
      jaxpr = self.jaxpr
    if consts is None:
      consts = self.consts
    return ClosedJaxpr(jaxpr, consts)

  def __str__(self):
    return str(self.jaxpr)

  def __repr__(self):
    return repr(self.jaxpr)

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   name_stack=False, custom_pp_eqn_rules=True,
                   print_effects=False, **kwargs):
    """Delegate to the wrapped jaxpr's `pretty_print`."""
    options = dict(
        source_info=source_info,
        print_shapes=print_shapes,
        name_stack=name_stack,
        custom_pp_eqn_rules=custom_pp_eqn_rules,
        print_effects=print_effects,
    )
    return self.jaxpr.pretty_print(**options, **kwargs)

  def _repr_pretty_(self, p, cycle):
    rendered = self.pretty_print(use_color=True)
    return p.text(rendered)
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  """Evaluate `closed_jaxpr` on `args` via `eval_jaxpr`.

  The evaluation runs with `config.debug_nans(False)`.
  """
  # TODO(dougalm): remove this hack when we add contexts to jaxpr.
  # debug_nans is sometimes disabled locally at the traceable level by ops that
  # work with nans internally, like jnp.var. The right thing to do is to add
  # contexts to our jaxpr representation so that we can capture these local
  # context modifications. In the meantime, disabling the checks when we
  # round-trip prevents those ops producing spurious errors.
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  with config.debug_nans(False):
    return eval_jaxpr(jaxpr, consts, *args)
class JaxprEqnContext:
  """Context settings recorded for an equation.

  Holds a compute type, the threefry-partitionable setting and XLA metadata,
  together with the context-manager factories used to re-establish them.
  """

  def __init__(self, compute_type: str | None, threefry_partitionable: bool,
               xla_metadata=None):
    self.compute_type = compute_type
    self.threefry_partitionable = threefry_partitionable
    self.xla_metadata = xla_metadata
    # (factory, argument) pairs; `manager` enters factory(argument) for each.
    self._managers = [
        (compute_on.extend_compute_type, compute_type),
        (config.threefry_partitionable.__call__, threefry_partitionable),
        (xla_metadata_lib.set_xla_metadata, xla_metadata),
    ]

  @property
  @contextmanager
  def manager(self):
    """Context manager that enters every recorded manager, in order."""
    with ExitStack() as stack:
      for make_manager, value in self._managers:
        stack.enter_context(make_manager(value))
      yield

  def __repr__(self):
    fields = ", ".join((
        f"compute_type={self.compute_type}",
        f"threefry_partitionable={self.threefry_partitionable}",
        f"xla_metadata={self.xla_metadata}",
    ))
    return "JaxprEqnContext(" + fields + ")"
class JaxprEqn:
  """One equation of a jaxpr: a primitive applied to inputs, giving outputs."""
  invars: list[Atom]
  outvars: list[Var]
  primitive: Primitive
  params: dict[str, Any]
  effects: Effects
  source_info: source_info_util.SourceInfo
  ctx: JaxprEqnContext

  # It's slightly faster to use a class with __slots__ than a NamedTuple.
  __slots__ = ['invars', 'outvars', 'primitive', 'params', 'effects',
               'source_info', 'ctx']

  def __init__(self, invars, outvars, primitive, params, effects, source_info,
               ctx):
    self.invars, self.outvars = invars, outvars
    self.primitive, self.params = primitive, params
    self.effects = effects
    self.source_info = source_info
    self.ctx = ctx

  def __repr__(self):
    rendered = pp_eqn(self, JaxprPpContext(), JaxprPpSettings())
    return str(rendered).rstrip()

  def replace(
      self,
      invars: list[Atom] | None = None,
      outvars: list[Var] | None = None,
      primitive: Primitive | None = None,
      params: dict[str, Any] | None = None,
      effects: Effects | None = None,
      source_info: source_info_util.SourceInfo | None = None,
      ctx: JaxprEqnContext | None = None
  ):
    """Return a new equation; arguments left as None keep the current value."""
    def pick(override, current):
      return current if override is None else override

    return JaxprEqn(
        pick(invars, self.invars),
        pick(outvars, self.outvars),
        pick(primitive, self.primitive),
        pick(params, self.params),
        pick(effects, self.effects),
        pick(source_info, self.source_info),
        pick(ctx, self.ctx),
    )
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
                  ctx=None):
  """Construct a `JaxprEqn`, filling in source info and context when missing.

  A falsy `source_info` is replaced by `source_info_util.new_source_info()`,
  and a falsy `ctx` by a `JaxprEqnContext` built from the current compute
  type, threefry setting and XLA metadata. When `config.enable_checks` is on,
  the inputs must be `Var`/`Literal` instances and the outputs `Var`s.
  """
  if not source_info:
    source_info = source_info_util.new_source_info()
  if not ctx:
    ctx = JaxprEqnContext(
        compute_on.current_compute_type(),
        config.threefry_partitionable.value,
        xla_metadata_lib.current_xla_metadata())
  if config.enable_checks.value:
    assert all(isinstance(atom, (Var, Literal)) for atom in invars)
    assert all(isinstance(var, Var) for var in outvars)
  return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx)
# Global counter; every new Var takes the next value as its `count`.
_var_counter = it.count()

@total_ordering
class Var:
  """A jaxpr variable with an abstract value and a printing suffix."""
  __slots__ = ["count", "suffix", "aval"]

  count: int
  suffix: str
  aval: AbstractValue

  def __init__(self, suffix: str, aval: AbstractValue):
    self.suffix = suffix
    self.aval = aval
    self.count = next(_var_counter)

  # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
  # care about variable ordering, but the downstream package kfac_jax does.
  def __lt__(self, other):
    mine, theirs = self.count, other.count
    return mine < theirs

  def __repr__(self):
    short_aval = self.aval.str_short()
    return f'Var(id={id(self)}){self.suffix}:{short_aval}'
def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]:
  """Return a factory mapping an abstract value to a fresh `Var`.

  Every variable produced by the factory carries the given `suffix`.
  """
  make_var = functools.partial(Var, suffix)
  return make_var
# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
  """Placeholder variable for an output that is never read; prints as `_`."""

  def __init__(self, aval: AbstractValue):
    # A dropped variable has no suffix.
    super().__init__('', aval)

  def __repr__(self):
    return '_'
class Literal:
  """A constant value that appears inline in a jaxpr, with its aval.

  `hash` caches a hash of the value when one can be computed. Literal
  instances themselves are unhashable (`__hash__` is None).
  """
  __slots__ = ["val", "aval", "hash"]

  val: Any
  aval: AbstractValue
  hash: int | None

  def __init__(self, val, aval):
    self.val = val
    self.aval = aval
    try:
      self.hash = hash(val)
    except TypeError:
      # NOTE: for an unhashable value whose type is not registered in
      # `literalable_types`, the `hash` slot is left unset; `__repr__` checks
      # for that with `hasattr`.
      if type(val) not in literalable_types:
        return
      try:
        self.hash = hash((val.item(), val.dtype))
      except (TypeError, AttributeError, ValueError):
        self.hash = None

  __hash__ = None  # type: ignore

  def __repr__(self):
    has_hash_slot = hasattr(self, 'hash')
    return f'{self.val}' if has_hash_slot else f'Literal(val={self.val})'
# Set of types consulted by `Literal.__init__` when a value is not directly
# hashable. It starts empty here and is presumably populated elsewhere.
literalable_types: set[type] = set()

# Type alias used in annotations for equation inputs and jaxpr outputs.
Atom = Union[Var, Literal]
class Primitive:
  """A named operation with pluggable implementation and abstract-eval rules."""
  name: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return f'{self.name}'

  def bind(self, *args, **params):
    """Apply this primitive to `args` under the current trace.

    Raises the error built by `escaped_tracer_error` if an argument is a
    Tracer whose trace is no longer valid, unless the
    `data_dependent_tracing_fallback` config option is set.
    """
    for arg in args:
      if not isinstance(arg, Tracer):
        continue
      if arg._trace.is_valid():
        continue
      if config.data_dependent_tracing_fallback.value:
        continue
      raise escaped_tracer_error(arg)
    # TODO: figure out how to handle function arguments
    # assert (not config.enable_checks.value or
    #         all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    with take_current_trace() as cur_trace:
      return self.bind_with_trace(cur_trace, args, params)

  def bind_with_trace(self, trace, args, params):
    """Hand this primitive to `trace.process_primitive`."""
    return trace.process_primitive(self, args, params)

  def def_impl(self, impl):
    """Install `impl` as the evaluation rule; returns `impl` unchanged."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Install an effect-free abstract evaluation rule; returns it unchanged."""
    wrapped = _effect_free_abstract_eval(abstract_eval)
    self.abstract_eval = wrapped
    return abstract_eval

  def def_effectful_abstract_eval(self, effectful_abstract_eval):
    """Install an abstract evaluation rule that also reports effects."""
    self.abstract_eval = effectful_abstract_eval
    return effectful_abstract_eval

  def def_bind_with_trace(self, bind_with_trace):
    """Override `bind_with_trace` on this instance; returns the argument."""
    self.bind_with_trace = bind_with_trace
    return bind_with_trace

  def impl(self, *args, **params):
    message = f"Evaluation rule for '{self.name}' not implemented"
    raise NotImplementedError(message)

  def abstract_eval(self, *args, **params):
    message = f"Abstract evaluation for '{self.name}' not implemented"
    raise NotImplementedError(message)

  def get_bind_params(self, params):
    """Return `(subfuns, bind_params)`; by default no subfuns."""
    return [], params
def _effect_free_abstract_eval(abstract_eval):
  """Wrap `abstract_eval` so that it also returns `no_effects`.

  The wrapper returns the pair `(abstract_eval(*args, **kwargs), no_effects)`.
  """
  def abstract_eval_(*args, **kwargs):
    out = abstract_eval(*args, **kwargs)
    return out, no_effects
  return abstract_eval_
# -------------------- lifting --------------------

# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
  """Applies f to each jaxpr parameter and returns a dict of the results.

  The result maps a parameter name to `f(p)` for parameters that are exactly a
  `Jaxpr` or `ClosedJaxpr`, or a tuple/list containing one. If a tuple/list
  holds several jaxprs, the entry for that name is the result for the last one.
  """
  results = {}
  for name, param in params.items():
    candidates = param if isinstance(param, (tuple, list)) else [param]
    for candidate in candidates:
      if type(candidate) in (Jaxpr, ClosedJaxpr):
        results[name] = f(candidate)
  return results
def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[Any]:
  """Evaluate `jaxpr` equation by equation and return its output values.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: values bound to `jaxpr.constvars`.
    *args: values bound to `jaxpr.invars`.
    propagate_source_info: when true, each equation's recorded traceback is
      passed to `source_info_util.user_context`; otherwise None is passed.

  Returns:
    The values of `jaxpr.outvars`.
  """
  env: dict[Var, Any] = {}

  def read(v: Atom) -> Any:
    if isinstance(v, Literal):
      return v.val
    return env[v]

  def write(v: Var, val: Any) -> None:
    if config.enable_checks.value and not config.dynamic_shapes.value:
      assert typecheck(v.aval, val), (v.aval, val)
    env[v] = val

  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  last_use = last_used(jaxpr)
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    subfuns, bind_params = prim.get_bind_params(eqn.params)
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    if propagate_source_info:
      traceback = eqn.source_info.traceback
    else:
      traceback = None
    with source_info_util.user_context(
        traceback, name_stack=name_stack), eqn.ctx.manager:
      ans = prim.bind(*subfuns, *map(read, eqn.invars), **bind_params)
    if not eqn.primitive.multiple_results:
      write(eqn.outvars[0], ans)
    else:
      map(write, eqn.outvars, ans)
    # Presumably drops env entries that no later equation reads.
    clean_up_dead_vars(eqn, env, last_use)
  return map(read, jaxpr.outvars)
# -------------------- tracing --------------------

TracerType = TypeVar('TracerType', bound='Tracer')

class Trace(Generic[TracerType]):
  """Base class for traces; subclasses override the `process_*` hooks."""

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def invalidate(self):
    """Mark this trace as no longer valid."""
    self._invalidated = True

  def is_valid(self):
    """True until `invalidate` has been called."""
    invalidated = hasattr(self, "_invalidated")
    return not invalidated

  def __repr__(self):
    return f'{self.__class__.__name__}'

  def process_call(self, call_primitive, f, tracers, params):
    raise NotImplementedError(
        f"{type(self)} must override process_call to handle call-like "
        "primitives")

  def process_map(self, map_primitive, f, tracers, params):
    raise NotImplementedError(
        f"{type(self)} must override process_map to handle map-like "
        "primitives")

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
                              symbolic_zeros):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_jvp_call "
        "to handle custom_jvp primitives")

  def process_custom_transpose(self, prim, call, tracers, **params):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_transpose "
        "to handle custom_transpose_call primitives")

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                              out_trees, symbolic_zeros):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_vjp_call "
        "to handle custom_vjp primitives")

  # TODO(dougalm): deprecate/delete
  def full_raise(self, x):
    return x

  # TODO(dougalm): deprecate/delete
  @property
  def main(self):
    return getattr(self, "tag", None)
def escaped_tracer_error(tracer, detail=None):
  """Build the error describing a tracer that escaped its transformation.

  The error is returned, not raised; callers write
  `raise escaped_tracer_error(...)`.

  Args:
    tracer: the leaked tracer. Its `aval` and type name go into the message,
      as do its `_debug_info` and `_line_info` attributes when present.
    detail: optional extra text appended at the end of the message.

  Returns:
    An `UnexpectedTracerError` carrying the assembled message.
  """
  num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
  msg = ('Encountered an unexpected tracer. A function transformed by JAX '
         'had a side effect, allowing for a reference to an intermediate value '
         f'with type {tracer.aval.str_short()} wrapped in a '
         f'{type(tracer).__name__} to escape the scope of the transformation.\n'
         'JAX transformations require that functions explicitly return their '
         'outputs, and disallow saving intermediate values to global state.')
  dbg = getattr(tracer, '_debug_info', None)
  if dbg is not None:
    msg += ('\nThe function being traced when the value leaked was '
            f'{dbg.func_src_info} traced for {dbg.traced_for}.')
  line_info = getattr(tracer, '_line_info', None)
  if line_info is not None:
    divider = '\n' + '-'*30 + '\n'
    msg += divider
    msg += ('The leaked intermediate value was created on line '
            f'{source_info_util.summarize(line_info)}. ')
    msg += divider
    if num_frames > 0:
      msg += (f'When the value was created, the final {num_frames} stack '
              'frames (most recent last) excluding JAX-internal frames were:')
      msg += divider + source_info_util.summarize(
          line_info, num_frames=num_frames) + divider
  msg += ('\nTo catch the leak earlier, try setting the environment variable '
          'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
          'manager.')
  if detail:
    # The preceding sentence ends with "manager."; without a separator the
    # detail would be glued onto it ("manager.Detail: ...").
    msg += f' Detail: {detail}'
  return UnexpectedTracerError(msg)
def check_scalar_conversion(arr: Array):
  """Raise TypeError unless `arr` is zero-dimensional."""
  if not arr.ndim > 0:
    return
  message = ("Only scalar arrays can be converted to Python scalars; "
             f"got {arr.ndim=}")
  raise TypeError(message)
def check_integer_conversion(arr: Array):
  """Raise TypeError unless `arr` has shape `()` and an integer dtype."""
  is_scalar_integer = (arr.shape == ()
                       and dtypes.issubdtype(arr.dtype, np.integer))
  if is_scalar_integer:
    return
  raise TypeError("Only integer scalar arrays can be converted to a scalar index.")
def check_bool_conversion(arr: Array):
  """Raise ValueError unless `arr` has exactly one element."""
  size = arr.size
  if size == 0:
    raise ValueError("The truth value of an empty array is ambiguous. Use"
                     " `array.size > 0` to check that an array is not empty.")
  elif size > 1:
    raise ValueError("The truth value of an array with more than one element"
                     " is ambiguous. Use a.any() or a.all()")
def _aval_property(name):
  """Return a read-only property that looks `name` up on `self.aval`."""
  def fget(self):
    return getattr(self.aval, name)
  return property(fget)
class Tracer(typing.Array, metaclass=StrictABCMeta):
  """Base class for traced array values.

  A tracer records the `Trace` it belongs to (`_trace`) and forwards most
  array-like behaviour to its abstract value, `aval`, which subclasses must
  provide. Many members that only make sense for concrete arrays raise an
  error instead (see the individual methods).
  """
  __array_priority__ = 1000
  __slots__ = ['_trace', '_line_info']
  # Tracers are unhashable.
  __hash__ = None  # type: ignore

  # These attributes are read from `self.aval`.
  dtype = _aval_property('dtype')
  ndim = _aval_property('ndim')
  size = _aval_property('size')
  shape = _aval_property('shape')

  def __init__(self, trace: Trace):
    self._trace = trace

  def _error_repr(self):
    # Short description of this tracer used inside error messages.
    if self.aval is None:
      return f"traced array with aval {self.aval}"
    return f"traced array with shape {self.aval.str_short()}"

  def __array__(self, *args, **kw):
    raise TracerArrayConversionError(self)

  def __dlpack__(self, *args, **kw):
    raise ConcretizationTypeError(self,
      f"The __dlpack__() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def tolist(self):
    raise ConcretizationTypeError(self,
      f"The tolist() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def tobytes(self, order="C"):
    del order
    raise ConcretizationTypeError(self,
      f"The tobytes() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  # TODO(dougalm): deprecate/delete
  def full_lower(self):
    raise NotImplementedError("must override: ", type(self))

  # Iteration, reversal and length are delegated to the aval (or to slicing).
  def __iter__(self):
    return iter(self.aval._iter(self))

  def __reversed__(self):
    return iter(self[::-1])

  def __len__(self):
    return self.aval._len(self)

  def to_concrete_value(self):
    # Should return the concrete value if there is one, or else None.
    return None

  @property
  def sharding(self):
    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
    # we raise an AttributeError so that hasattr() and getattr() work as expected.
    raise AttributeError(self,
      f"The 'sharding' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def committed(self):
    raise ConcretizationTypeError(
        self,
        f"The 'committed' attribute is not available on {self._error_repr()}."
        f"{self._origin_msg()}")

  @property
  def device(self):
    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
    # we raise an AttributeError so that hasattr() and getattr() work as expected.
    raise AttributeError(self,
      f"The 'device' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def addressable_shards(self):
    raise ConcretizationTypeError(self,
      f"The 'addressable_shards' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def at(self):
    # Calls the getter of the aval's `at` property with this tracer as `self`.
    return self.aval.at.fget(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def get_referent(self) -> Any:
    return self  # Override for object equivalence checking

  # Python scalar conversions. Where present, the `is_concrete` branch converts
  # the concrete value directly; otherwise the request is validated and then
  # delegated to the aval.
  def __bool__(self):
    if is_concrete(self): return bool(self.to_concrete_value())  # pytype: disable=wrong-arg-types
    check_bool_conversion(self)
    return self.aval._bool(self)

  def __int__(self):
    if is_concrete(self): return int(self.to_concrete_value())  # pytype: disable=wrong-arg-types
    check_scalar_conversion(self)
    return self.aval._int(self)

  def __float__(self):
    check_scalar_conversion(self)
    return self.aval._float(self)

  def __complex__(self):
    check_scalar_conversion(self)
    return self.aval._complex(self)

  def __hex__(self):
    if is_concrete(self): return hex(self.to_concrete_value())  # pytype: disable=wrong-arg-types
    check_integer_conversion(self)
    return self.aval._hex(self)

  def __oct__(self):
    if is_concrete(self): return oct(self.to_concrete_value())  # pytype: disable=wrong-arg-types
    check_integer_conversion(self)
    return self.aval._oct(self)

  def __index__(self):
    if is_concrete(self): return operator.index(self.to_concrete_value())  # pytype: disable=wrong-arg-types
    check_integer_conversion(self)
    return self.aval._index(self)

  # raises a useful error on attempts to pickle a Tracer.
  def __reduce__(self):
    raise ConcretizationTypeError(
      self, ("The error occurred in the __reduce__ method, which may "
             "indicate an attempt to serialize/pickle a traced value."))

  # raises the better error message from ShapedArray
  def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    # Fallback lookup: attributes not found on the tracer are looked up on the
    # aval. `aval_property`/`aval_method` wrappers found there are applied to
    # this tracer; anything else is returned as-is.
    # if the aval property raises an AttributeError, gets caught here
    assert not config.enable_checks.value or name != "aval"

    try:
      attr = getattr(self.aval, name)
    except AttributeError as err:
      raise AttributeError(
          f"{self.__class__.__name__} has no attribute {name}"
      ) from err
    else:
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        return types.MethodType(attr.fun, self)
      else:
        return attr

  def _pretty_print(self):
    # Pretty-printer document for this tracer; entries from `_contents()` are
    # appended as "name = value" items after " with".
    base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
    contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
                 else pp.text(repr(attr))) for name, attr in self._contents()]
    if contents:
      base = pp.group(pp.nest(2, pp.concat([
        base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
          pp.text(f'{name} = ') + pp_payload
          for name, pp_payload in contents])
      ])))
    return base

  def __repr__(self):
    return self._pretty_print().format()

  def _contents(self):
    # (name, value) pairs for every slot; empty if any slot is unset.
    try:
      return [(name, getattr(self, name)) for name in self.__slots__]
    except AttributeError:
      return ()

  def _origin_msg(self) -> str:
    # Suffix appended to error messages; empty in this base class.
    return ""

  # Methods that are only valid for materialized arrays
  def addressable_data(self, index):
    raise ConcretizationTypeError(self,
      f"The addressable_data() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def block_until_ready(self):
    # Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
    raise AttributeError(self,
      f"The 'block_until_ready' method is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def copy_to_host_async(self):
    # Raise AttributeError for backward compatibility with hasattr() and getattr() checks.
    raise AttributeError(self,
      f"The 'copy_to_host_async' method is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

  def delete(self):
    raise ConcretizationTypeError(self,
      f"The delete() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def devices(self):
    raise ConcretizationTypeError(self,
      f"The devices() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def global_shards(self):
    raise ConcretizationTypeError(self,
      f"The global_shards property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def is_deleted(self):
    raise ConcretizationTypeError(self,
      f"The is_deleted() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def is_fully_addressable(self):
    raise ConcretizationTypeError(self,
      f"The is_fully_addressable property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def is_fully_replicated(self):
    raise ConcretizationTypeError(self,
      f"The is_fully_replicated property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def on_device_size_in_bytes(self):
    raise ConcretizationTypeError(self,
      f"The on_device_size_in_bytes() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  @property
  def traceback(self):
    raise ConcretizationTypeError(self,
      f"The traceback property was called on {self._error_repr()}."
      f"{self._origin_msg()}")

  def unsafe_buffer_pointer(self):
    raise ConcretizationTypeError(self,
      f"The unsafe_buffer_pointer() method was called on {self._error_repr()}."
      f"{self._origin_msg()}")
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals. `Tracer.__getattr__` checks for
# these exact types: an `aval_property` has its `fget` called with the tracer,
# and an `aval_method` has its `fun` bound to the tracer.
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace):
  """Trace that evaluates primitives by calling their `impl` rules."""

  def process_primitive(self, primitive, args, params):
    if config.debug_key_reuse.value:
      # Import here to avoid circular imports
      from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # pytype: disable=import-error
      return call_impl_with_key_reuse_checks(primitive, primitive.impl, *args, **params)
    # TODO(dougalm): delete. this shouldn't be necessary
    lowered = map(full_lower, args)
    for arg in lowered:
      if not isinstance(arg, Tracer):
        continue
      # A tracer reached evaluation: either hand over to its own trace or
      # report it as escaped, depending on the config option.
      if not config.data_dependent_tracing_fallback.value:
        raise escaped_tracer_error(arg)
      return primitive.bind_with_trace(arg._trace, lowered, params)
    return primitive.impl(*lowered, **params)

  def process_call(self, primitive, f, tracers, params):
    if not config.debug_key_reuse.value:
      return primitive.impl(f, *tracers, **params)
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks  # pytype: disable=import-error
    return call_impl_with_key_reuse_checks(primitive, primitive.impl, f, *tracers, **params)

  process_map = process_call

  def process_custom_transpose(self, primitive, call, tracers, **_):
    del primitive, _
    return call.call_wrapped(*tracers)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_):
    del primitive, jvp, _  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_):  # pytype: disable=signature-mismatch
    del primitive, fwd, bwd, _  # Unused.
    return fun.call_wrapped(*tracers)
class TraceTag:
  """Tag object; all instances compare equal and share a single hash."""
  # TODO: this works for surprisingly subtle reasons. Function transformations
  # like `jvp_subtrace` are parameterized by a tag that identifies the set of
  # pre-existing tracers we want to unpack during the transformation. A function
  # defined in an outer scope can't have any closed-over traces, so the tag is
  # irrelevant. A function defined in the current scope may have closed-over
  # traces, but the tag will never change so we'll never get a spurious cache
  # hit. The plan is to do away with `lu.cache` altogether, and use a simpler
  # caching scheme that only caches top-level functions. Then we can remove this
  # hack.

  def __eq__(self, other):
    # Any two TraceTag instances are equal.
    return isinstance(other, TraceTag)

  def __hash__(self):
    # Hash of the class itself, so every instance hashes alike.
    return hash(TraceTag)
# Type alias for a str-keyed dict of parameter values.
ParamDict = dict[str, Any]
# Axis names may be any hashable value.
AxisName = Hashable
# Sentinel name; `AxisEnv.extend_pure` and `AxisEnv.as_hashable_key` skip
# entries whose name is this object.
no_axis_name = object()
@dataclass(frozen=True)
class AxisEnv:
  """Mapping from axis names to sizes, plus a set of SPMD axis names.

  The `*_pure` and `add_*` methods return new `AxisEnv` objects and leave the
  receiver unchanged.
  """
  axis_sizes : dict[AxisName, int]
  spmd_axis_names : set[AxisName]

  def axis_size(self, axis_name):
    """Return the size bound to `axis_name`; NameError if it is unbound."""
    if axis_name in self.axis_sizes:
      return self.axis_sizes[axis_name]
    raise NameError(f"unbound axis name: {axis_name}")

  def axis_exists(self, axis_name):
    return axis_name in self.axis_sizes

  def axis_names(self):
    return tuple(self.axis_sizes)

  def pop_pure(self, axis_name):
    """Return a copy without `axis_name` (KeyError if it is absent)."""
    remaining = self.axis_sizes.copy()
    del remaining[axis_name]
    return AxisEnv(remaining, self.spmd_axis_names)

  def extend_pure(self, name_size_pairs):
    """Return a copy extended with the pairs whose name is not `no_axis_name`."""
    extended = self.axis_sizes.copy()
    for name, size in name_size_pairs:
      if name is not no_axis_name:
        extended[name] = size
    return AxisEnv(extended, self.spmd_axis_names)

  def add_spmd_axis_names(self, axis_names):
    """Return a copy whose SPMD axis names also include `axis_names`."""
    extra = set(axis_names)
    return AxisEnv(self.axis_sizes, self.spmd_axis_names | extra)

  def as_hashable_key(self):
    """Return the (name, size) pairs as a tuple, skipping `no_axis_name`."""
    pairs = []
    for name, size in self.axis_sizes.items():
      if name is not no_axis_name:
        pairs.append((name, size))
    return tuple(pairs)
# Module-level EvalTrace instance.
eval_trace = EvalTrace()
# Axis environment with no axis sizes and no SPMD axis names.
top_axis_env = AxisEnv({}, set())
class TracingContext(threading.local):
trace: Trace | None
axis_env : AxisEnv
def __init__(self):
self.reset()