@@ -836,45 +836,40 @@ def make_node(self, a, b):
836
836
837
837
if 0 in (a .ndim , b .ndim ):
838
838
raise ValueError ("inputs to `matmul` cannot be scalar." )
839
- elif a .ndim >= 2 and b .ndim >= 2 :
840
- out = at .TensorType (
841
- dtype = self .dtype ,
842
- broadcastable = (False ,) * max (a .ndim , b .ndim ),
843
- )()
844
- else :
845
- out_ndim = b .ndim - 1 if a .ndim == 1 else a .ndim - 1
846
- out = at .TensorType (
847
- dtype = self .dtype ,
848
- broadcastable = (False ,) * out_ndim ,
849
- )()
839
+
840
+ out_shape = self ._get_output_shape (a , b , (a .type .shape , b .type .shape ))
841
+ out = at .TensorType (dtype = self .dtype , shape = out_shape )()
850
842
return Apply (self , [a , b ], [out ])
851
843
852
844
def perform (self , node , inputs , outputs ):
853
845
x1 , x2 = inputs
854
846
outputs [0 ][0 ] = np .matmul (x1 , x2 , dtype = self .dtype )
855
847
856
- def infer_shape (self , fgraph , node , shapes ):
848
+ def _get_output_shape (self , x1 , x2 , shapes ):
857
849
x1_shape , x2_shape = shapes
858
- x1 , x2 = node .inputs
859
850
860
851
if x1 .ndim == 1 and x2 .ndim == 1 :
861
- return [()]
852
+ return ()
862
853
elif x1 .ndim == 1 and x2 .ndim > 1 :
863
- return [ x2_shape [:- 2 ] + x2_shape [- 1 :] ]
854
+ return x2_shape [:- 2 ] + x2_shape [- 1 :]
864
855
elif x1 .ndim > 1 and x2 .ndim == 1 :
865
- return [ x1_shape [:- 1 ] ]
856
+ return x1_shape [:- 1 ]
866
857
elif x1 .ndim == 2 and x2 .ndim == 2 :
867
- return [ x1_shape [:- 1 ] + x2_shape [- 1 :] ]
858
+ return x1_shape [:- 1 ] + x2_shape [- 1 :]
868
859
elif x1 .ndim > 2 and x2 .ndim == 2 :
869
- return [ x1_shape [:- 2 ] + x1_shape [- 2 :- 1 ] + x2_shape [- 1 :] ]
860
+ return x1_shape [:- 2 ] + x1_shape [- 2 :- 1 ] + x2_shape [- 1 :]
870
861
elif x1 .ndim == 2 and x2 .ndim > 2 :
871
- return [ x2_shape [:- 2 ] + x1_shape [- 2 :- 1 ] + x2_shape [- 1 :] ]
862
+ return x2_shape [:- 2 ] + x1_shape [- 2 :- 1 ] + x2_shape [- 1 :]
872
863
else :
873
- return [
864
+ return (
874
865
broadcast_shape (x1_shape [:- 2 ], x2_shape [:- 2 ], arrays_are_shapes = True )
875
866
+ x1_shape [- 2 :- 1 ]
876
867
+ x2_shape [- 1 :]
877
- ]
868
+ )
869
+
870
+ def infer_shape (self , fgraph , node , shapes ):
871
+ x1 , x2 = node .inputs
872
+ return [self ._get_output_shape (x1 , x2 , shapes )]
878
873
879
874
880
875
def matmul (x1 , x2 , dtype = None ):
0 commit comments