Commit fcb2d5c

simplify logic by making use of tensor.type.shape and removing broadcastable
1 parent c824a82 commit fcb2d5c
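
For context, a minimal sketch (not part of this commit) of the two TensorType spellings involved. It assumes a recent Aesara in which TensorType accepts a shape argument with None marking statically unknown dimensions, which is what x.type.shape reports and what makes the old broadcastable=(False,) * ndim construction unnecessary; the variables below are illustrative only.

import aesara.tensor as at

# New spelling: static shape information, None = length unknown at graph build time.
x = at.TensorType(dtype="float64", shape=(None, 3))()
print(x.type.shape)  # (None, 3)
print(x.ndim)        # 2

# The old broadcastable=(False,) * ndim spelling carried no length information,
# which corresponds to shape=(None,) * ndim in the new form.
y = at.TensorType(dtype="float64", shape=(None, None))()
print(y.type.shape)  # (None, None)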

2 files changed: +18 -23 lines

aesara/tensor/nlinalg.py (+16 -21)
@@ -836,45 +836,40 @@ def make_node(self, a, b):
 
         if 0 in (a.ndim, b.ndim):
             raise ValueError("inputs to `matmul` cannot be scalar.")
-        elif a.ndim >= 2 and b.ndim >= 2:
-            out = at.TensorType(
-                dtype=self.dtype,
-                broadcastable=(False,) * max(a.ndim, b.ndim),
-            )()
-        else:
-            out_ndim = b.ndim - 1 if a.ndim == 1 else a.ndim - 1
-            out = at.TensorType(
-                dtype=self.dtype,
-                broadcastable=(False,) * out_ndim,
-            )()
+
+        out_shape = self._get_output_shape(a, b, (a.type.shape, b.type.shape))
+        out = at.TensorType(dtype=self.dtype, shape=out_shape)()
         return Apply(self, [a, b], [out])
 
     def perform(self, node, inputs, outputs):
         x1, x2 = inputs
         outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)
 
-    def infer_shape(self, fgraph, node, shapes):
+    def _get_output_shape(self, x1, x2, shapes):
         x1_shape, x2_shape = shapes
-        x1, x2 = node.inputs
 
         if x1.ndim == 1 and x2.ndim == 1:
-            return [()]
+            return ()
         elif x1.ndim == 1 and x2.ndim > 1:
-            return [x2_shape[:-2] + x2_shape[-1:]]
+            return x2_shape[:-2] + x2_shape[-1:]
         elif x1.ndim > 1 and x2.ndim == 1:
-            return [x1_shape[:-1]]
+            return x1_shape[:-1]
         elif x1.ndim == 2 and x2.ndim == 2:
-            return [x1_shape[:-1] + x2_shape[-1:]]
+            return x1_shape[:-1] + x2_shape[-1:]
         elif x1.ndim > 2 and x2.ndim == 2:
-            return [x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]]
+            return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
         elif x1.ndim == 2 and x2.ndim > 2:
-            return [x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]]
+            return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
         else:
-            return [
+            return (
                 broadcast_shape(x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True)
                 + x1_shape[-2:-1]
                 + x2_shape[-1:]
-            ]
+            )
+
+    def infer_shape(self, fgraph, node, shapes):
+        x1, x2 = node.inputs
+        return [self._get_output_shape(x1, x2, shapes)]
 
 
 def matmul(x1, x2, dtype=None):
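
As an aside (not part of the diff), the branches of the new _get_output_shape helper mirror np.matmul's shape rules, which a quick NumPy check illustrates; the last shape pair below is one of the cases exercised in the updated test file.

import numpy as np

# Sketch only: the output shapes np.matmul produces for the cases handled above.
print(np.matmul(np.ones(3), np.ones(3)).shape)                     # ()           vector @ vector
print(np.matmul(np.ones((2, 4, 3)), np.ones(3)).shape)             # (2, 4)       trailing dim contracted
print(np.matmul(np.ones((4, 5)), np.ones((5, 7))).shape)           # (4, 7)       plain matrix product
print(np.matmul(np.ones((3, 2, 4, 5)), np.ones((1, 5, 7))).shape)  # (3, 2, 4, 7) batch dims broadcast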

tests/tensor/test_nlinalg.py (+2 -2)
@@ -183,8 +183,8 @@ def test_infer_shape(self):
             ((3, 2, 4, 5), (1, 5, 7)),
             ((4, 5, 8, 2, 3), (3, 2, 3, 4, 5, 8, 3, 3)),
         ]:
-            a = tensor(dtype=config.floatX, broadcastable=[i == 1 for i in shape_x1])
-            b = tensor(dtype=config.floatX, broadcastable=[i == 1 for i in shape_x2])
+            a = tensor(dtype=config.floatX, shape=shape_x1)
+            b = tensor(dtype=config.floatX, shape=shape_x2)
             x1 = self.rng.random(shape_x1).astype(config.floatX)
             x2 = self.rng.random(shape_x2).astype(config.floatX)
 
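One more hedged note (an assumption about Aesara's tensor helper rather than something shown in the diff): passing a fully concrete shape also fixes broadcastability, since size-1 dimensions are treated as broadcastable, so the old [i == 1 for i in shape_x1] flags are implied by the new call.

from aesara import config
from aesara.tensor import tensor

# Sketch only -- assumes `tensor` accepts the `shape` keyword used in the test above.
a = tensor(dtype=config.floatX, shape=(3, 1, 5))
print(a.type.shape)          # (3, 1, 5)
print(a.type.broadcastable)  # (False, True, False): the size-1 dim is broadcastable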
