@@ -50,6 +50,7 @@ def _check(self, s, *input_shapes, **kwargs):
5050 b = self .evaluate (gen_linalg_ops .einsum (input_tensors , s ))
5151 self .assertAllClose (a , b , atol = 1e-4 , rtol = 1e-4 )
5252
53+ @test_util .disable_xla ('b/131919749' )
5354 def testUnary (self ):
5455 self ._check ('->' , ())
5556 self ._check ('aa->' , (3 , 3 ))
@@ -66,7 +67,10 @@ def testUnary(self):
6667 self ._check ('aabcc->ac' , (3 , 3 , 5 , 4 , 4 ))
6768 self ._check ('aabcd->ad' , (3 , 3 , 5 , 4 , 4 ))
6869
70+ @test_util .disable_xla ('b/131919749' )
6971 def testUnaryEllipsis (self ):
72+ # Unary cases with ellipsis.
73+ # Edge cases.
7074 self ._check ('...->...' , ())
7175 self ._check ('...->' , ())
7276 self ._check ('->...' , ())
@@ -78,31 +82,44 @@ def testUnaryEllipsis(self):
7882 self ._check ('a...a->a...' , (2 , 1 , 2 ))
7983 self ._check ('a...a->a...' , (2 , 3 , 4 , 5 , 2 ))
8084
85+ # Regular cases.
8186 self ._check ('...ijk->...ki' , (3 , 4 , 5 ))
8287 self ._check ('...ijk->...ki' , (1 , 3 , 4 , 5 ))
8388 self ._check ('...ijk->...ki' , (2 , 2 , 3 , 4 , 5 ))
8489
8590 # Repeated indices.
8691 self ._check ('i...ii->...i' , (3 , 2 , 3 , 3 ))
8792
88- def testBinary (self ):
93+ def testBinarySimple (self ):
94+ # Binary cases in XLA mode must have either (a) each index appearing exactly
95+ # once in both the inputs (batch or contraction index), or (b) appearing
96+ # exactly once in an input and in the output (free index).
8997 self ._check (',->' , (), ())
9098 self ._check ('a,a->' , (3 ,), (3 ,))
9199 self ._check ('a,a->a' , (3 ,), (3 ,))
92- self ._check ('ba,b->' , (3 , 2 ), (3 ,))
93100 self ._check ('ab,b->a' , (3 , 4 ), (4 ,))
94101 self ._check ('ab,ab->' , (3 , 4 ), (3 , 4 ))
102+ self ._check ('ab,bc->ac' , (3 , 4 ), (4 , 5 ))
95103 self ._check ('nij,jk->nik' , (5 , 2 , 3 ), (3 , 4 ))
96104 self ._check ('abc,bad->abcd' , (1 , 2 , 3 ), (2 , 1 , 4 ))
105+ # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
106+ self ._check ('sa,shb->shab' , (2 , 1 ), (2 , 3 , 4 ))
107+
108+ @test_util .disable_xla ('b/131919749' )
109+ def testReducedIndices (self ):
110+ self ._check ('ba,b->' , (3 , 2 ), (3 ,))
111+ self ._check ('ab,ab->' , (3 , 4 ), (3 , 4 ))
112+
113+ @test_util .disable_xla ('b/131919749' )
114+ def testRepeatedIndices (self ):
97115 # Repeated indices.
98116 self ._check ('ijj,k->ik' , (2 , 3 , 3 ), (4 ,))
99117 self ._check ('aba,a->b' , (3 , 4 , 3 ), (3 ,))
100118 # From https://github.com/dask/dask/pull/3412#discussion_r182413444
101119 self ._check ('aab,bc->ac' , (2 , 2 , 3 ), (3 , 4 ))
102120 self ._check ('aab,bcc->ac' , (2 , 2 , 3 ), (3 , 4 , 4 ))
103- # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
104- self ._check ('sa,shb->shab' , (2 , 1 ), (2 , 3 , 4 ))
105121
122+ @test_util .disable_xla ('b/131919749' )
106123 def testBroadcasting (self ):
107124 # Batch matmul without broadcasting.
108125 self ._check ('...ij,...jk->...ik' , (5 , 1 , 2 , 3 ), (5 , 1 , 3 , 4 ))
@@ -113,14 +130,17 @@ def testBroadcasting(self):
113130 self ._check ('...ij,...jk->...ik' , (2 , 3 ), (5 , 3 , 5 ))
114131 self ._check ('...ij,...jk->...ik' , (3 , 1 , 2 , 3 ), (1 , 1 , 7 , 3 , 5 ))
115132 self ._check ('i...j,j...k->...ik' , (2 , 1 , 3 , 1 , 3 ), (3 , 1 , 7 , 5 ))
133+ # Following 2 from # https://stackoverflow.com/a/19203475/1611416
134+ self ._check ('...abc,...abcd->...d' , (1 , 1 , 2 , 3 , 4 ), (5 , 2 , 3 , 4 , 6 ))
135+ self ._check ('ab...,b->ab...' , (2 , 3 , 1 , 1 , 5 ), (3 ,))
136+
137+ @test_util .disable_xla ('b/131919749' )
138+ def testBroadcastingWithRepeatedIndices (self ):
116139 # Broadcasting with repeated indices.
117140 self ._check ('ij,jk...k->i...' , (3 , 2 ), (2 , 4 , 1 , 4 ))
118141 self ._check ('ij,jk...k->...i' , (3 , 2 ), (2 , 4 , 5 , 4 ))
119142 self ._check ('ijj,jk...k->i...' , (3 , 2 , 2 ), (2 , 4 , 1 , 4 ))
120143 self ._check ('i...jj,jk...k->i...' , (3 , 3 , 1 , 2 , 2 ), (2 , 4 , 1 , 5 , 4 ))
121- # Following 2 from # https://stackoverflow.com/a/19203475/1611416
122- self ._check ('...abc,...abcd->...d' , (1 , 1 , 2 , 3 , 4 ), (5 , 2 , 3 , 4 , 6 ))
123- self ._check ('ab...,b->ab...' , (2 , 3 , 1 , 1 , 5 ), (3 ,))
124144
125145 def testDtypes (self ):
126146 bfloat16 = dtypes .bfloat16 .as_numpy_dtype
@@ -153,6 +173,7 @@ def check(dtype):
153173 ]:
154174 check (dtype )
155175
176+ @test_util .disable_xla ('b/131919749' )
156177 @test_util .run_in_graph_and_eager_modes
157178 def testInvalid (self ):
158179 r = np .random .RandomState (0 )
@@ -178,6 +199,7 @@ def testInvalid(self):
178199 with self .assertRaises ((ValueError , errors .InvalidArgumentError )):
179200 _ = self .evaluate (gen_linalg_ops .einsum (placeholders , args [0 ]))
180201
202+ @test_util .disable_xla ('b/131919749' )
181203 @test_util .run_in_graph_and_eager_modes
182204 def testPlaceholder (self ):
183205
@@ -202,9 +224,11 @@ def check(equation, *input_and_placeholder_shapes):
202224 ((4 , 3 ), (None , 3 )))
203225 check ('...ij,...jk->...ik' , ((3 , 1 , 2 , 3 ), None ), ((1 , 7 , 3 , 4 ), None ))
204226
227+ @test_util .disable_xla ('b/131919749' )
205228 def testOutputRepeatedLabels (self ):
206- # This is the reverse operation of repeated input labels, to be used for
207- # computing symbolic gradients of einsum.
229+ # This is the reverse operation of generalized traces, to be used for
230+ # computing symbolic gradients of einsum. Note: this operation is not
231+ # supported by np.einsum as it's only required for gradients.
208232 r = np .random .RandomState (0 )
209233 a = r .randn (2 , 2 )
210234 s = 'a->aa'
0 commit comments