Skip to content

Commit

Permalink
Fixed the regularization for BGU. (halide#7684)
Browse files Browse the repository at this point in the history
Co-authored-by: Steven Johnson <srj@google.com>
  • Loading branch information
2 people authored and ardier committed Mar 3, 2024
1 parent 8be86f1 commit 56cd44e
Showing 1 changed file with 14 additions and 33 deletions.
47 changes: 14 additions & 33 deletions apps/bgu/bgu_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,39 +395,20 @@ class BGU : public Generator<BGU> {
b(2, 2) = blurx(x, y, z, 20);
b(3, 2) = blurx(x, y, z, 21);

// Regularize by pushing the solution towards the average gain
// in this cell = (average output luma + eps) / (average input luma + eps).
const float lambda = 1e-6f;
const float epsilon = 1e-6f;

// The bottom right entry of A is a count of the number of
// constraints affecting this cell.
Expr N = A(3, 3);

// The last row of each matrix is the sum of input and output
// RGB values for the pixels affecting this cell. Instead of
// dividing them by N+1 to get averages, we'll multiply
// epsilon by N+1. This saves two divisions.
Expr output_luma = b(3, 0) + 2 * b(3, 1) + b(3, 2) + epsilon * (N + 1);
Expr input_luma = A(3, 0) + 2 * A(3, 1) + A(3, 2) + epsilon * (N + 1);
Expr gain = output_luma / input_luma;

// Add lambda and lambda*gain to the diagonal of the
// matrices. The matrices are sums/moments rather than
// means/covariances, so just like above we need to multiply
// lambda by N+1 so that it's equivalent to adding a constant
// to the diagonal of a covariance matrix. Otherwise it does
// nothing in cells with lots of linearly-dependent
// constraints.
Expr weighted_lambda = lambda * (N + 1);
A(0, 0) += weighted_lambda;
A(1, 1) += weighted_lambda;
A(2, 2) += weighted_lambda;
A(3, 3) += weighted_lambda;

b(0, 0) += weighted_lambda * gain;
b(1, 1) += weighted_lambda * gain;
b(2, 2) += weighted_lambda * gain;
// Regularize it with 1/10th of a sample that pulls the result towards the identity function.
// Regions in the grid that are well populated will have way more than 1/10th of a sample.
// The original paper on BGU had a more complex regularization scheme, but the regularization
// logic was backwards: when a cell has fewer samples, it got less regularization; and when
// the cell has a lot of samples, it got regularized a lot.
const float lambda = 1e-1f;
A(0, 0) += lambda;
A(1, 1) += lambda;
A(2, 2) += lambda;
A(3, 3) += lambda;

b(0, 0) += lambda;
b(1, 1) += lambda;
b(2, 2) += lambda;

// Now solve Ax = b
Matrix<3, 4> result = transpose(solve_symmetric(A, b, line, x, using_autoscheduler(), get_target()));
Expand Down

0 comments on commit 56cd44e

Please sign in to comment.