@@ -258,33 +258,36 @@ StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
258258 return result;
259259}
260260
261- // Computes W and Y such that I-WY is equivalent to the sequence of Householder
262- // transformations given by vs and taus.
263- // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
261+ // Computes T such that (I + Y @ T @ Y^t) is a product of the elementary
262+ // Householder reflectors given by `vs` and `taus`.
263+ //
264+ // Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
265+ // representation for products of Householder transformations." SIAM Journal on
266+ // Scientific and Statistical Computing 10.1 (1989): 53-57.
267+ //
268+ // m, n = vs.shape[-2:]
269+ // t = np.zeros((n, n))
264270// Y = np.zeros([m, n])
265- // W = np.zeros([m, n])
271+ // t[0, 0] = -taus[0]
266272// Y[:, 0] = vs[:, 0]
267- // W[:, 0] = -taus[0] * vs[:, 0]
268- // for j in xrange(1, n):
269- // v = vs[:, j]
270- // z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
271- // W[:, j] = z
272- // Y[:, j] = v
273- // return W
274- // There is no need to return Y since at termination of the loop it is equal to
275- // vs.
276- StatusOr<XlaOp> ComputeWYRepresentation (PrimitiveType type,
273+ // for i in range(1, n):
274+ // z = -taus[i] * np.dot(t, np.dot(Y.T, vs[:, i]))
275+ // Y[:, i] = vs[:, i]
276+ // t[:i, i] = z[:i]
277+ // t[i, i] = -taus[i]
278+ StatusOr<XlaOp> CompactWYRepresentation (PrimitiveType type,
277279 absl::Span<const int64> batch_dims,
278280 XlaOp vs, XlaOp taus, int64 m, int64 n,
279281 PrecisionConfig::Precision precision) {
280282 std::vector<int64> batch_dim_indices (batch_dims.size ());
281283 std::iota (batch_dim_indices.begin (), batch_dim_indices.end (), 0 );
284+ int64 m_index = batch_dims.size ();
282285 int64 n_index = batch_dims.size () + 1 ;
283286
284287 auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
285288 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
286289 // t has shape [..., n, n]
287- auto w = values[0 ];
290+ auto t = values[0 ];
288291 const auto vs = values[1 ];
289292 const auto taus = values[2 ];
290293
@@ -303,31 +306,37 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
303306 auto y = Select (Ge (iota_mn, j), ZerosLike (vs), vs);
304307
305308 // yv has shape [..., n, 1]
306- auto yv = BatchDot (y, true , v, false , precision);
307- // wyv has shape [..., m, 1]
308- auto wyv = BatchDot (w, yv, precision);
309+ auto yv =
310+ BatchDot (y, /* transpose_x=*/ true , v, /* transpose_y=*/ false , precision);
311+ // wyv has shape [..., n, 1]
312+ auto wyv = BatchDot (t, yv, precision);
309313
310314 auto z = Mul (
311- -beta, v + wyv,
315+ -beta, wyv,
312316 /* broadcast_dimensions=*/ ConcatVectors (batch_dim_indices, {n_index}));
317+ beta = BroadcastInDim (beta, ConcatVectors (batch_dims, {n, 1 }),
318+ ConcatVectors (batch_dim_indices, {n_index}));
319+ auto iota_n = Iota (
320+ builder, ShapeUtil::MakeShape (S32, ConcatVectors (batch_dims, {n, 1 })),
321+ m_index);
313322
314- w = DynamicUpdateSliceInMinorDims (w, z, {j} );
323+ z = Select ( Lt (iota_n, j), z, Select ( Eq (iota_n, j), -beta, ZerosLike (beta)) );
315324
316- return std::vector<XlaOp>{w, vs, taus};
325+ t = DynamicUpdateSliceInMinorDims (t, z, {j});
326+
327+ return std::vector<XlaOp>{t, vs, taus};
317328 };
318329
319330 XlaBuilder* builder = vs.builder ();
320- auto w = Zeros (builder,
321- ShapeUtil::MakeShape (type, ConcatVectors (batch_dims, {m, n})));
322- auto v = SliceInMinorDims (vs, {0 }, {1 });
331+ auto t = Zeros (builder,
332+ ShapeUtil::MakeShape (type, ConcatVectors (batch_dims, {n, n})));
323333 auto beta = SliceInMinorDims (taus, {0 }, {1 });
324- auto bv =
325- Mul (-beta, v,
326- /* broadcast_dimensions=*/ ConcatVectors (batch_dim_indices, {n_index}));
327- w = UpdateSliceInMinorDims (w, bv, {0 });
334+ beta = BroadcastInDim (beta, ConcatVectors (batch_dims, {1 , 1 }),
335+ ConcatVectors (batch_dim_indices, {n_index}));
336+ t = UpdateSliceInMinorDims (t, -beta, {0 });
328337
329338 TF_ASSIGN_OR_RETURN (auto values, ForEachIndex (n - 1 , S32, body_fn,
330- {w , vs, taus}, " wy" , builder));
339+ {t , vs, taus}, " wy" , builder));
331340 return values[0 ];
332341}
333342
@@ -342,12 +351,10 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
342351// k = min(block_size, min(m, n) - i)
343352// (a, vs, taus) = qr(a[i:, i:i+k])
344353// y = vs
345- // w = ComputeWYRepresentation (vs, taus, m-i, k)
346- // a[i:, i+r :] += np.dot(y, np.dot(w.T, a[i:, i+k:]) )
347- // q[:, i:] += np.dot (q[:, i:], np.dot(w, y .T))
354+ // t = CompactWYRepresentation (vs, taus, m-i, k)
355+ // a[i:, i+k :] += (y @ t.T) @ (y.T @ a[i:, i+k:])
356+ // q[:, i:] += (q[:, i:] @ y) @ (y @ t .T).T
348357// return (q, a)
349- // TODO(phawkins): consider using UT transformations (in the form I - V U V')
350- // rather than WY transformations.
351358StatusOr<QRDecompositionResult> QRDecomposition (
352359 XlaOp a, bool full_matrices, int64 block_size,
353360 PrecisionConfig::Precision precision) {
@@ -384,24 +391,28 @@ StatusOr<QRDecompositionResult> QRDecomposition(
384391
385392 a = UpdateSliceInMinorDims (a, qr_block.r , {i, i});
386393
387- // Compute the I-WY block representation of a product of Householder
388- // matrices.
394+ // Compute the I + Y @ T @ Y^t block representation of a product of
395+ // Householder matrices.
389396 TF_ASSIGN_OR_RETURN (
390- auto w, ComputeWYRepresentation (type, batch_dims, qr_block.vs ,
397+ auto t, CompactWYRepresentation (type, batch_dims, qr_block.vs ,
391398 qr_block.taus , m - i, k, precision));
392399 auto y = qr_block.vs ;
393400
394- // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
401+ // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
402+ auto yt =
403+ BatchDot (y, /* transpose_x=*/ false , t, /* transpose_y=*/ true , precision);
395404 auto a_panel = SliceInMinorDims (a, {i, i + k}, {m, n});
396- auto a_update = BatchDot (w, true , a_panel, false , precision);
397- a_update = BatchDot (y, a_update, precision);
405+ auto a_update = BatchDot (y, /* transpose_x=*/ true , a_panel,
406+ /* transpose_y=*/ false , precision);
407+ a_update = BatchDot (yt, a_update, precision);
398408 a_panel = a_panel + a_update;
399409 a = UpdateSliceInMinorDims (a, a_panel, {i, i + k});
400410
401- // q[:, i:] += np.dot(np.dot( q[:, i:], W), Y .T))
411+ // q[:, i:] += ( q[:, i:] @ y) @ (y @ t .T).T
402412 auto q_panel = SliceInMinorDims (q, {0 , i}, {m, m});
403- auto q_update = BatchDot (q_panel, w, precision);
404- q_update = BatchDot (q_update, false , y, true , precision);
413+ auto q_update = BatchDot (q_panel, y, precision);
414+ q_update = BatchDot (q_update, /* transpose_x=*/ false , yt,
415+ /* transpose_y=*/ true , precision);
405416 q_panel = q_panel + q_update;
406417 q = UpdateSliceInMinorDims (q, q_panel, {0 , i});
407418 }
0 commit comments