77from pytensor .graph import Apply , FunctionGraph , Op , Type , node_rewriter
88from pytensor .graph .rewriting .basic import in2out
99from pytensor .scalar import constant
10- from pytensor .tensor import (
11- NoneConst ,
12- add ,
13- and_ ,
14- empty ,
15- get_scalar_constant_value ,
16- set_subtensor ,
17- )
10+ from pytensor .tensor import add , and_ , empty , get_scalar_constant_value , set_subtensor
1811from pytensor .tensor .exceptions import NotScalarConstantError
1912from pytensor .tensor .shape import Shape_i
13+ from pytensor .tensor .subtensor import Subtensor , get_idx_list
2014from pytensor .tensor .type import DenseTensorType , TensorType
2115from pytensor .tensor .type_other import NoneTypeT
16+ from pytensor .typed_list import GetItem , TypedListType , append , make_empty_list
2217
2318
2419def validate_loop_update_types (update ):
@@ -176,8 +171,7 @@ def __init__(
176171 )
177172 )
178173 else :
179- # We can't concatenate all types of states, such as RandomTypes
180- self .trace_types .append (NoneConst .type )
174+ self .trace_types .append (TypedListType (state_type ))
181175
182176 self .constant_types = [inp .type for inp in update_fg .inputs [self .n_states :]]
183177 self .n_constants = len (self .constant_types )
@@ -312,10 +306,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
312306 if fgraph .clients [trace ]
313307 ]
314308
315- # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
316- for trace_idx in used_traces_idxs :
317- assert not isinstance (old_states [trace_idx ].type , NoneTypeT )
318-
319309 # Inputs to the new Loop
320310 max_iters = node .inputs [0 ]
321311 init_states = node .inputs [1 : 1 + op .n_states ]
@@ -324,6 +314,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
324314 (max_iters , * tuple (init_states [trace_idx ].shape )),
325315 dtype = init_states [trace_idx ].dtype ,
326316 )
317+ if isinstance (init_states [trace_idx ].type , DenseTensorType )
318+ else make_empty_list (init_states [trace_idx ].type )
327319 for trace_idx in used_traces_idxs
328320 ]
329321 constants = node .inputs [1 + op .n_states :]
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
387379 inner_while_cond , * inner_next_states = update_fg .outputs
388380 inner_next_traces = [
389381 set_subtensor (prev_trace [inner_idx ], inner_next_states [trace_idx ])
382+ if isinstance (prev_trace .type , DenseTensorType )
383+ else append (prev_trace , inner_next_states [trace_idx ])
390384 for trace_idx , prev_trace in zip (used_traces_idxs , inner_traces )
391385 ]
392386 for t in inner_next_traces :
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
429423 replacements = dict (zip (old_states , new_states ))
430424 for trace_idx , new_trace in zip (used_traces_idxs , new_traces ):
431425 # If there is no while condition, the whole trace will be used
432- if op .has_while_condition :
426+ if op .has_while_condition and isinstance ( new_trace . type , DenseTensorType ) :
433427 new_trace = new_trace [:final_idx ]
434428 replacements [old_traces [trace_idx ]] = new_trace
435429 return replacements
@@ -446,3 +440,39 @@ def scan(fn, idx, initial_states, constants, max_iters):
446440 "not_jax" ,
447441 position = 1.0 ,
448442)
443+
444+
445+ @node_rewriter ([Scan ])
446+ def scan_view_last_state (fgraph , node ):
447+ """Replace trace[-1] by the last state output of a Scan node"""
448+ replacements = {}
449+ for final_state , trace in zip (
450+ node .outputs [: node .op .n_states ], node .outputs [node .op .n_states :]
451+ ):
452+ clients = fgraph .clients [trace ]
453+ for client , _ in clients :
454+ if client == "output" :
455+ continue
456+ if isinstance (client .op , (Subtensor , GetItem )):
457+ if isinstance (client .op , Subtensor ):
458+ idxs = get_idx_list (client .inputs , client .op .idx_list )
459+ if len (idxs ) == 1 :
460+ idx = idxs [0 ]
461+ else :
462+ idx = client .inputs [1 ]
463+ try :
464+ last_index = get_scalar_constant_value (idx ) == - 1
465+ except NotScalarConstantError :
466+ continue
467+ if last_index :
468+ replacements [client .default_output ()] = final_state
469+ return replacements
470+
471+
472+ optdb .register (
473+ "scan_view_last_state" ,
474+ in2out (scan_view_last_state ),
475+ "fast_compile" ,
476+ "fast_run" ,
477+ position = 0.999 ,
478+ )
0 commit comments