From 01a4af2aad166a77bf85b041fe466aff87d21818 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Date: Mon, 25 Mar 2024 11:33:00 +0100
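Subject: [PATCH] Fix bug when broadcasting branches in local_useless_switch
 rewrite

When the two branches of a switch are equivalent up to constant casting but
the kept branch has to be broadcast against the condition, the rewrite
returned the broadcast branch without casting it to the dtype of the switch
output. The branch is now broadcast first when its broadcast pattern differs
from the output, then cast whenever its dtype differs from the output dtype,
and the stack trace is copied from the node's outputs and inputs.
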
---
 pytensor/tensor/rewriting/basic.py   | 19 ++++++++-----------
 tests/tensor/rewriting/test_basic.py | 19 +++++++++++++++++++
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 17fd595fdb..33bf799058 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -1024,18 +1024,15 @@ def local_useless_switch(fgraph, node):
 
     # if left is right -> left
     if equivalent_up_to_constant_casting(left, right):
-        if left.type.broadcastable == out_bcast:
-            out_dtype = node.outputs[0].type.dtype
-            if left.type.dtype != out_dtype:
-                left = cast(left, out_dtype)
-                copy_stack_trace(node.outputs + left, left)
-            # When not casting, the other inputs of the switch aren't needed in the traceback
-            return [left]
+        if left.type.broadcastable != out_bcast:
+            left, _ = broadcast_arrays(left, cond)
 
-        else:
-            ret = broadcast_arrays(left, cond)[0]
-            copy_stack_trace(node.outputs + left, ret)
-            return [ret]
+        out_dtype = node.outputs[0].type.dtype
+        if left.type.dtype != out_dtype:
+            left = cast(left, out_dtype)
+
+        copy_stack_trace(node.outputs + node.inputs, left)
+        return [left]
 
     # This case happens with scan.
     # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index f0a70dd5ee..dc1f3e099a 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1089,6 +1089,25 @@ def test_broadcasting_3(self):
         assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc)
         assert not any(node.op == pt.switch for node in f.maker.fgraph.toposort())
 
+    def test_broadcasting_different_dtype(self):
+        """Equal-valued branches of different dtypes collapse to one alloc."""
+        mask = vector("x", dtype="bool")
+        narrow = as_tensor(np.array([0], dtype="float32"))
+        wide = as_tensor(np.array([0], dtype="float64"))
+
+        # Whatever the branch order, the expected graph is the float64 zero
+        # allocated to the shape of the condition.
+        expected_out = pt.alloc(wide, mask.shape)
+        passes = ("canonicalize", "stabilize", "specialize")
+
+        # Try the float32 constant in the true position first, then in the
+        # false position.
+        branch_orders = ((narrow, wide), (wide, narrow))
+        for true_branch, false_branch in branch_orders:
+            out = pt.switch(mask, true_branch, false_branch)
+            rewritten_out = rewrite_graph(out, include=passes)
+            assert equal_computations([rewritten_out], [expected_out])
+
 
 class TestLocalMergeSwitchSameCond:
     @pytest.mark.parametrize(