3
3
from copy import deepcopy
4
4
import itertools
5
5
from collections import OrderedDict
6
- from operator import itemgetter
7
6
from functools import reduce
8
7
import typing as ty
9
8
from . import helpers_state as hlpst
@@ -41,7 +40,7 @@ def __init__(self, indices: dict[str, int] | None = None):
41
40
if indices is None :
42
41
self .indices = OrderedDict ()
43
42
else :
44
- self .indices = OrderedDict (sorted ( indices .items () ))
43
+ self .indices = OrderedDict (indices .items ())
45
44
46
45
def __len__ (self ) -> int :
47
46
return len (self .indices )
@@ -53,13 +52,12 @@ def __getitem__(self, key: str) -> int:
53
52
return self .indices [key ]
54
53
55
54
def __lt__ (self , other : "StateIndex" ) -> bool :
56
- if set (self .indices ) != set (other .indices ):
55
+ if list (self .indices ) != list (other .indices ):
57
56
raise ValueError (
58
- f"StateIndex { self } does not contain the same indices as { other } "
57
+ f"StateIndex { self } does not contain the same indices in the same order "
58
+ f"as { other } : { list (self .indices )} != { list (other .indices )} "
59
59
)
60
- return sorted (self .indices .items (), key = itemgetter (0 )) < sorted (
61
- other .indices .items (), key = itemgetter (0 )
62
- )
60
+ return tuple (self .indices .items ()) < tuple (other .indices .items ())
63
61
64
62
def __repr__ (self ) -> str :
65
63
return (
@@ -273,24 +271,29 @@ def depth(self, after_combine: bool = True) -> int:
273
271
int
274
272
number of splits in the state (i.e. linked splits only add 1)
275
273
"""
276
- depth = 0
277
- stack = []
278
274
279
- def included (s ):
280
- return s not in self .combiner if after_combine else True
275
+ # replace field names with 1 or 0 (1 if the field is included in the state)
276
+ include_rpn = [
277
+ (
278
+ s
279
+ if s in ["." , "*" ]
280
+ else (int (s not in self .combiner ) if after_combine else 1 )
281
+ )
282
+ for s in self .splitter_rpn
283
+ ]
281
284
282
- for spl in self . splitter_rpn :
283
- if spl in [ "." , "*" ] :
284
- if spl == "." :
285
- depth += int ( all ( included ( s ) for s in stack ))
286
- else :
287
- assert spl == "*"
288
- depth += len ([ s for s in stack if included ( s )])
289
- stack = []
285
+ stack = []
286
+ for opr in include_rpn :
287
+ if opr == "." :
288
+ assert len ( stack ) >= 2
289
+ stack . append ( stack . pop () and stack . pop ())
290
+ elif opr == "*" :
291
+ assert len (stack ) >= 2
292
+ stack . append ( stack . pop () + stack . pop ())
290
293
else :
291
- stack .append (spl )
292
- remaining_stack = [ s for s in stack if included ( s )]
293
- return depth + len ( remaining_stack )
294
+ stack .append (opr )
295
+ assert len ( stack ) == 1
296
+ return stack [ 0 ]
294
297
295
298
def nest_output_type (self , type_ : type ) -> type :
296
299
"""Nests a type of an output field in a combination of lists and state-arrays
0 commit comments