@@ -4659,41 +4659,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
 
 @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
 @pytest.mark.parametrize(
-    "batch_left", [True, False], ids=["batched_left", "unbatched_left"]
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
 )
 @pytest.mark.parametrize(
-    "batch_right", [True, False], ids=["batched_right", "unbatched_right"]
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
 )
-def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right):
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
     """
     Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
     """
+
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )
+
     a = tensor("a", shape=(4, 2))
-    b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
     c = tensor("c", shape=(4, 4))
-    d = tensor("d", shape=(10, 10))
-    e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10))
-
     x = pt.linalg.block_diag(a, b, c)
 
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
     # Test multiple clients are all rewritten
     if left_multiply:
-        out = [x @ d, x @ e]
+        out = x @ d
     else:
-        out = [d @ x, e @ x]
+        out = d @ x
 
-    with config.change_flags(optimizer_verbose=True):
-        fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
-
-    assert not any(
-        isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
-    )
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
 
     fn_expected = pytensor.function(
-        [a, b, c, d, e],
+        [a, b, c, d],
         out,
         mode=Mode(linker="py", optimizer=None),
     )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
 
     # TODO: Count Dots
 
@@ -4702,18 +4717,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch
     b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
     c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
     d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
-    e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
 
-    rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val)
-    expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val)
-
-    for out, expected in zip(rewrite_outs, expected_outs):
-        np.testing.assert_allclose(
-            out,
-            expected,
-            atol=1e-6 if config.floatX == "float32" else 1e-12,
-            rtol=1e-6 if config.floatX == "float32" else 1e-12,
-        )
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
 
 
 @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
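The docstring above describes the rewrite this test exercises. As a reference for that identity, here is a minimal NumPy/SciPy sketch (not part of the commit, and using plain NumPy rather than PyTensor) with the same shapes as the unbatched test case: multiplying block_diag(a, b, c) by d on the left is the row-wise concatenation of the per-block dots, with d split by rows according to each block's column count (2, 4, 4).

# Illustrative sketch only; variable names mirror the test but are otherwise arbitrary.
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 2))
b = rng.normal(size=(2, 4))
c = rng.normal(size=(4, 4))
d = rng.normal(size=(10, 10))

# Full product through the explicit block-diagonal matrix (10 x 10).
full = block_diag(a, b, c) @ d
# Equivalent form: dot each block with the matching row-slice of d, then stack.
split = np.concatenate([a @ d[:2], b @ d[2:6], c @ d[6:]], axis=0)
np.testing.assert_allclose(full, split)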