Skip to content

Commit dc57189

Browse files
hawkinsptensorflower-gardener
authored andcommitted
[XLA] Use the compact WY representation in the implementation of blocked QR decompositions.
Compact WY representations are described in: Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY representation for products of Householder transformations." SIAM Journal on Scientific and Statistical Computing 10.1 (1989): 53-57. The compact WY representation is more storage efficient, requiring calculation of an nxn triangular matrix, where n is the block size (e.g., 128), instead of an mxn matrix where m is the number of matrix rows. PiperOrigin-RevId: 330711085 Change-Id: Ideac239ff118ee6ac2fd1397b731a40e11d6ecd7
1 parent 1f50fde commit dc57189

File tree

1 file changed

+54
-43
lines changed
  • tensorflow/compiler/xla/client/lib

1 file changed

+54
-43
lines changed

tensorflow/compiler/xla/client/lib/qr.cc

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -258,33 +258,36 @@ StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
258258
return result;
259259
}
260260

261-
// Computes W and Y such that I-WY is equivalent to the sequence of Householder
262-
// transformations given by vs and taus.
263-
// Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
261+
// Computes T such that (I - Y @ T @ Y^t) is a product of the elementary
262+
// Householder reflectors given by `vs` and `taus`.
263+
//
264+
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
265+
// representation for products of Householder transformations." SIAM Journal on
266+
// Scientific and Statistical Computing 10.1 (1989): 53-57.
267+
//
268+
// m, n = vs.shape[-2:]
269+
// t = np.zeros((n, n))
264270
// Y = np.zeros([m, n])
265-
// W = np.zeros([m, n])
271+
// t[0, 0] = -taus[0]
266272
// Y[:, 0] = vs[:, 0]
267-
// W[:, 0] = -taus[0] * vs[:, 0]
268-
// for j in xrange(1, n):
269-
// v = vs[:, j]
270-
// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
271-
// W[:, j] = z
272-
// Y[:, j] = v
273-
// return W
274-
// There is no need to return Y since at termination of the loop it is equal to
275-
// vs.
276-
StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
273+
// for i in range(1, n):
274+
// z = -taus[i] * np.dot(t, np.dot(Y.T, vs[:, i]))
275+
// Y[:, i] = vs[:, i]
276+
// t[:i, i] = z[:i]
277+
// t[i, i] = -taus[i]
278+
StatusOr<XlaOp> CompactWYRepresentation(PrimitiveType type,
277279
absl::Span<const int64> batch_dims,
278280
XlaOp vs, XlaOp taus, int64 m, int64 n,
279281
PrecisionConfig::Precision precision) {
280282
std::vector<int64> batch_dim_indices(batch_dims.size());
281283
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
284+
int64 m_index = batch_dims.size();
282285
int64 n_index = batch_dims.size() + 1;
283286

284287
auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
285288
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
286289
// w has shape [..., m, n]
287-
auto w = values[0];
290+
auto t = values[0];
288291
const auto vs = values[1];
289292
const auto taus = values[2];
290293

@@ -303,31 +306,37 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
303306
auto y = Select(Ge(iota_mn, j), ZerosLike(vs), vs);
304307

305308
// yv has shape [..., n, 1]
306-
auto yv = BatchDot(y, true, v, false, precision);
307-
// wyv has shape [..., m, 1]
308-
auto wyv = BatchDot(w, yv, precision);
309+
auto yv =
310+
BatchDot(y, /*transpose_x=*/true, v, /*transpose_y=*/false, precision);
311+
// wyv has shape [..., n, 1]
312+
auto wyv = BatchDot(t, yv, precision);
309313

310314
auto z = Mul(
311-
-beta, v + wyv,
315+
-beta, wyv,
312316
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
317+
beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {n, 1}),
318+
ConcatVectors(batch_dim_indices, {n_index}));
319+
auto iota_n = Iota(
320+
builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n, 1})),
321+
m_index);
313322

314-
w = DynamicUpdateSliceInMinorDims(w, z, {j});
323+
z = Select(Lt(iota_n, j), z, Select(Eq(iota_n, j), -beta, ZerosLike(beta)));
315324

316-
return std::vector<XlaOp>{w, vs, taus};
325+
t = DynamicUpdateSliceInMinorDims(t, z, {j});
326+
327+
return std::vector<XlaOp>{t, vs, taus};
317328
};
318329

319330
XlaBuilder* builder = vs.builder();
320-
auto w = Zeros(builder,
321-
ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
322-
auto v = SliceInMinorDims(vs, {0}, {1});
331+
auto t = Zeros(builder,
332+
ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n, n})));
323333
auto beta = SliceInMinorDims(taus, {0}, {1});
324-
auto bv =
325-
Mul(-beta, v,
326-
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
327-
w = UpdateSliceInMinorDims(w, bv, {0});
334+
beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {1, 1}),
335+
ConcatVectors(batch_dim_indices, {n_index}));
336+
t = UpdateSliceInMinorDims(t, -beta, {0});
328337

329338
TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(n - 1, S32, body_fn,
330-
{w, vs, taus}, "wy", builder));
339+
{t, vs, taus}, "wy", builder));
331340
return values[0];
332341
}
333342

@@ -342,12 +351,10 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
342351
// k = min(block_size, min(m, n) - s)
343352
// (a, vs, taus) = qr(a[i:, i:i+k])
344353
// y = vs
345-
// w = ComputeWYRepresentation(vs, taus, m-i, k)
346-
// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
347-
// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
354+
// t = CompactWYRepresentation(vs, taus, m-i, k)
355+
// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
356+
// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
348357
// return (q, a)
349-
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
350-
// rather than WY transformations.
351358
StatusOr<QRDecompositionResult> QRDecomposition(
352359
XlaOp a, bool full_matrices, int64 block_size,
353360
PrecisionConfig::Precision precision) {
@@ -384,24 +391,28 @@ StatusOr<QRDecompositionResult> QRDecomposition(
384391

385392
a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
386393

387-
// Compute the I-WY block representation of a product of Householder
388-
// matrices.
394+
// Compute the I + Y @ T @ Y^t block representation of a product of
395+
// Householder matrices.
389396
TF_ASSIGN_OR_RETURN(
390-
auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
397+
auto t, CompactWYRepresentation(type, batch_dims, qr_block.vs,
391398
qr_block.taus, m - i, k, precision));
392399
auto y = qr_block.vs;
393400

394-
// a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
401+
// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
402+
auto yt =
403+
BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
395404
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
396-
auto a_update = BatchDot(w, true, a_panel, false, precision);
397-
a_update = BatchDot(y, a_update, precision);
405+
auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
406+
/*transpose_y=*/false, precision);
407+
a_update = BatchDot(yt, a_update, precision);
398408
a_panel = a_panel + a_update;
399409
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
400410

401-
// q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
411+
// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
402412
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
403-
auto q_update = BatchDot(q_panel, w, precision);
404-
q_update = BatchDot(q_update, false, y, true, precision);
413+
auto q_update = BatchDot(q_panel, y, precision);
414+
q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
415+
/*transpose_y=*/true, precision);
405416
q_panel = q_panel + q_update;
406417
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
407418
}

0 commit comments

Comments
 (0)