It can be significantly more efficient, both in runtime and memory, to directly define derivative rules for higher-level linear algebra operations like `solve` rather than the constituent operations (factorization and triangular solve).
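For reference, this works because the standard reverse-mode rule for a linear solve only needs additional solves against the transposed matrix, never a fresh factorization (a sketch of the textbook identities, not anything specific to this prototype):

```latex
% Reverse-mode rule for x = A^{-1} b, given an output cotangent \bar{x}:
\bar{b} = A^{-T}\,\bar{x}, \qquad \bar{A} = -\,\bar{b}\,x^{T}
```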
For large matrices (e.g., 500x500 on CPUs), my microbenchmark shows that we can get a 3-4x speed-up for general purpose and symmetric solves:
```python
from functools import partial
import jax
import jax.numpy as np
import jax.scipy as jsp
import numpy as onp
from jax import lax

def positive_definite_solve(a, b):
  # Factor once; the symmetric solve reuses the Cholesky factors.
  factors = jsp.linalg.cho_factor(a)
  def solve(matvec, x):
    return jsp.linalg.cho_solve(factors, x)
  matvec = partial(np.dot, a)
  return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

def linear_solve(a, b):
  # Factor once; the transpose solve reuses the LU factors via trans=1.
  a_factors = jsp.linalg.lu_factor(a)
  def solve(matvec, x):
    return jsp.linalg.lu_solve(a_factors, x)
  def transpose_solve(vecmat, x):
    return jsp.linalg.lu_solve(a_factors, x, trans=1)
  matvec = partial(np.dot, a)
  return lax.custom_linear_solve(matvec, b, solve, transpose_solve)

def loss(solve):
  def f(a, b):
    return solve(a, b).sum()
  return f

rs = onp.random.RandomState(0)
a = rs.randn(500, 500)
a = jax.device_put(a.T @ a + 0.1 * np.eye(500))
b = jax.device_put(rs.randn(500))

# general purpose solve
# current
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
%timeit jax.device_get(grad(a, b))
# 33.8 ms per loop
# new
grad = jax.jit(jax.grad(loss(linear_solve)))
%timeit jax.device_get(grad(a, b))
# 10.1 ms per loop

# positive definite solve
# current
grad = jax.jit(jax.grad(loss(partial(jsp.linalg.solve, sym_pos=True))))
%timeit jax.device_get(grad(a, b))
# 23.7 ms per loop
# new
grad = jax.jit(jax.grad(loss(positive_definite_solve)))
%timeit jax.device_get(grad(a, b))
# 4.8 ms per loop
```
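As a quick sanity check (reusing the setup above; this is just a sketch, not part of the benchmark), the gradients from the custom-rule solves should agree with the stock solves up to float32 round-off:

```python
# Sanity check, reusing the arrays and functions defined above: gradients from
# the custom-rule solves should match the stock solves up to float32 round-off
# (amplified somewhat by the conditioning of `a`).
for reference, candidate in [
    (np.linalg.solve, linear_solve),
    (partial(jsp.linalg.solve, sym_pos=True), positive_definite_solve),
]:
  ref_grad = jax.grad(loss(reference))(a, b)
  new_grad = jax.grad(loss(candidate))(a, b)
  print(np.max(np.abs(ref_grad - new_grad)))  # expect a small number
```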
Unfortunately, we can't just use these prototype implementations internally in JAX, for two reasons:
1. `custom_linear_solve` (like custom transforms in general) doesn't work with batching yet (custom_transforms vjp rule clobbered under vmap #1249). Update: this was solved by Batching rule for custom_linear_solve #2099.
2. We do an extra optimization in `triangular_solve_jvp_rule_a` for the case of solving many right-hand sides at the same time with the same left-hand side (Speedup JVP for triangular solve #1466), which the new gradient rule here doesn't handle yet. Update: in practice, I don't think this optimization actually matters -- it's the difference between `n*m*m + m*m*m` time and `2*m*m*m` time (a rough cost comparison is sketched below).
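To put rough numbers on that last point (illustrative only, with a hypothetical right-hand-side count n; this is an operation count, not a measurement):

```python
# Illustrative asymptotic operation counts only (made-up m and n, not measured):
m, n = 500, 10
with_optimization = n * m * m + m * m * m  # reuse one factorization across the n RHS columns
without_optimization = 2 * m * m * m       # two full O(m^3) passes
print(with_optimization / without_optimization)  # ~0.51: at most a 2x gap, shrinking as n -> m
```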
Thanks to @mattjj for pointing out that LU solve has the `trans` argument, which means we can use a single factorization for both forward and reverse calculations to speed up solves on all types of matrices.
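As a toy illustration of that point (reusing the imports above; a 5x5 example I made up, not part of the benchmark), `lu_solve` with `trans=1` solves the transposed system, so a single `lu_factor` call covers both the forward and the transpose (reverse-mode) solve:

```python
# Toy check (made-up 5x5 example): lu_solve with trans=1 solves against a.T,
# so one lu_factor call serves both the forward and transpose solves.
a_small = np.asarray(onp.random.RandomState(1).randn(5, 5))
x_small = np.asarray(onp.random.RandomState(2).randn(5))
factors_small = jsp.linalg.lu_factor(a_small)
diff = jsp.linalg.lu_solve(factors_small, x_small, trans=1) - np.linalg.solve(a_small.T, x_small)
print(np.max(np.abs(diff)))  # expect ~0, up to float32 round-off
```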