Skip to content

Commit 55bb36f

Browse files
bloops and tensorflower-gardener
authored and committed
Add TF2XLA registration for Einsum op, so that it uses XlaEinsum implementation. (Try #2)
This is needed because we want to switch TF Classic from the python implementation of tf.einsum to the fused EinsumOp. After this, the registration would enable compilation of the new TF Graphs in XLA. Note: It will still emit XlaEinsum and not Einsum when the enclosing context is TPU. PiperOrigin-RevId: 268083168
1 parent 83c5225 commit 55bb36f

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

tensorflow/compiler/tf2xla/kernels/einsum_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class EinsumOp : public XlaOpKernel {
4848
};
4949

5050
REGISTER_XLA_OP(Name("XlaEinsum").TypeConstraint("T", kEinsumTypes), EinsumOp);
51+
REGISTER_XLA_OP(Name("Einsum").TypeConstraint("T", kEinsumTypes), EinsumOp);
5152

5253
} // namespace
5354
} // namespace tensorflow

tensorflow/python/kernel_tests/einsum_op_test.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _check(self, s, *input_shapes, **kwargs):
5050
b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
5151
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
5252

53+
@test_util.disable_xla('b/131919749')
5354
def testUnary(self):
5455
self._check('->', ())
5556
self._check('aa->', (3, 3))
@@ -66,7 +67,10 @@ def testUnary(self):
6667
self._check('aabcc->ac', (3, 3, 5, 4, 4))
6768
self._check('aabcd->ad', (3, 3, 5, 4, 4))
6869

70+
@test_util.disable_xla('b/131919749')
6971
def testUnaryEllipsis(self):
72+
# Unary cases with ellipsis.
73+
# Edge cases.
7074
self._check('...->...', ())
7175
self._check('...->', ())
7276
self._check('->...', ())
@@ -78,31 +82,44 @@ def testUnaryEllipsis(self):
7882
self._check('a...a->a...', (2, 1, 2))
7983
self._check('a...a->a...', (2, 3, 4, 5, 2))
8084

85+
# Regular cases.
8186
self._check('...ijk->...ki', (3, 4, 5))
8287
self._check('...ijk->...ki', (1, 3, 4, 5))
8388
self._check('...ijk->...ki', (2, 2, 3, 4, 5))
8489

8590
# Repeated indices.
8691
self._check('i...ii->...i', (3, 2, 3, 3))
8792

88-
def testBinary(self):
93+
def testBinarySimple(self):
94+
# Binary cases in XLA mode must have either (a) each index appearing exactly
95+
# once in both the inputs (batch or contraction index), or (b) appearing
96+
# exactly once in an input and in the output (free index).
8997
self._check(',->', (), ())
9098
self._check('a,a->', (3,), (3,))
9199
self._check('a,a->a', (3,), (3,))
92-
self._check('ba,b->', (3, 2), (3,))
93100
self._check('ab,b->a', (3, 4), (4,))
94101
self._check('ab,ab->', (3, 4), (3, 4))
102+
self._check('ab,bc->ac', (3, 4), (4, 5))
95103
self._check('nij,jk->nik', (5, 2, 3), (3, 4))
96104
self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
105+
# Based on https://github.com/google/jax/issues/37#issuecomment-448572187
106+
self._check('sa,shb->shab', (2, 1), (2, 3, 4))
107+
108+
@test_util.disable_xla('b/131919749')
109+
def testReducedIndices(self):
110+
self._check('ba,b->', (3, 2), (3,))
111+
self._check('ab,ab->', (3, 4), (3, 4))
112+
113+
@test_util.disable_xla('b/131919749')
114+
def testRepeatedIndices(self):
97115
# Repeated indices.
98116
self._check('ijj,k->ik', (2, 3, 3), (4,))
99117
self._check('aba,a->b', (3, 4, 3), (3,))
100118
# From https://github.com/dask/dask/pull/3412#discussion_r182413444
101119
self._check('aab,bc->ac', (2, 2, 3), (3, 4))
102120
self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
103-
# Based on https://github.com/google/jax/issues/37#issuecomment-448572187
104-
self._check('sa,shb->shab', (2, 1), (2, 3, 4))
105121

122+
@test_util.disable_xla('b/131919749')
106123
def testBroadcasting(self):
107124
# Batch matmul without broadcasting.
108125
self._check('...ij,...jk->...ik', (5, 1, 2, 3), (5, 1, 3, 4))
@@ -113,14 +130,17 @@ def testBroadcasting(self):
113130
self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
114131
self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
115132
self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
133+
# Following 2 from # https://stackoverflow.com/a/19203475/1611416
134+
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
135+
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
136+
137+
@test_util.disable_xla('b/131919749')
138+
def testBroadcastingWithRepeatedIndices(self):
116139
# Broadcasting with repeated indices.
117140
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
118141
self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
119142
self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
120143
self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
121-
# Following 2 from # https://stackoverflow.com/a/19203475/1611416
122-
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
123-
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
124144

125145
def testDtypes(self):
126146
bfloat16 = dtypes.bfloat16.as_numpy_dtype
@@ -153,6 +173,7 @@ def check(dtype):
153173
]:
154174
check(dtype)
155175

176+
@test_util.disable_xla('b/131919749')
156177
@test_util.run_in_graph_and_eager_modes
157178
def testInvalid(self):
158179
r = np.random.RandomState(0)
@@ -178,6 +199,7 @@ def testInvalid(self):
178199
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
179200
_ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
180201

202+
@test_util.disable_xla('b/131919749')
181203
@test_util.run_in_graph_and_eager_modes
182204
def testPlaceholder(self):
183205

@@ -202,9 +224,11 @@ def check(equation, *input_and_placeholder_shapes):
202224
((4, 3), (None, 3)))
203225
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
204226

227+
@test_util.disable_xla('b/131919749')
205228
def testOutputRepeatedLabels(self):
206-
# This is the reverse operation of repeated input labels, to be used for
207-
# computing symbolic gradients of einsum.
229+
# This is the reverse operation of generalized traces, to be used for
230+
# computing symbolic gradients of einsum. Note: this operation is not
231+
# supported by np.einsum as it's only required for gradients.
208232
r = np.random.RandomState(0)
209233
a = r.randn(2, 2)
210234
s = 'a->aa'

0 commit comments

Comments
 (0)