Work out traceable JVP and transpose rules for JAX #17

Open
dfm opened this issue Nov 9, 2020 · 1 comment
Labels
enhancement New feature or request


dfm (Member) commented Nov 9, 2020

It should be possible to write the JVP rules using existing celerite primitives. Because such a rule is itself built from differentiable operations, JAX could trace through it again, which would support higher-order differentiation, hopefully without significant computational overhead.

For example, the matmul_lower JVP can be implemented as follows:

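```python
from jax import lax
from jax.interpreters import ad

# matmul_lower (the traceable wrapper) and matmul_lower_p (the primitive)
# are assumed to be defined elsewhere in the celerite2 JAX ops.


def matmul_lower_jvp(arg_values, arg_tangents):
    def make_zero(x, tan):
        # Replace symbolic zero tangents (ad.Zero) with dense zeros
        return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

    t, c, U, V, Y = arg_values
    tp, cp, Up, Vp, Yp = (
        make_zero(x, tan) for x, tan in zip(arg_values, arg_tangents)
    )

    # Tangent of the exponent c_j * t_n; U is scaled by exp(-c t), V by exp(c t)
    dphase = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -dphase * U + Up
    Vt = dphase * V + Vp

    # Product rule: Z is linear in each of U, V, and Y separately
    Zp = matmul_lower(t, c, U, V, Yp)
    Zp += matmul_lower(t, c, Ut, V, Y)
    Zp += matmul_lower(t, c, U, Vt, Y)

    # No tangent is propagated for the primitive's second output
    return matmul_lower_p.bind(t, c, U, V, Y), (Zp, None)
```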

But I haven't figured out the correct transpose yet.
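The product-rule structure of this JVP can be sanity-checked without JAX. The sketch below assumes that `matmul_lower(t, c, U, V, Y)` equals `tril(exp(-c t) U (exp(c t) V)^T, -1) @ Y`, the same dense form that the `np.tril(Up @ Vp.T, -1) @ Y` check later in this thread uses. The dense `matmul_lower` here is a hypothetical O(N^2) stand-in, not the celerite2 primitive, and the random data and step size are illustrative choices. It compares the JVP formula against a central finite difference along the same tangent direction:

```python
import numpy as np


def matmul_lower(t, c, U, V, Y):
    """Hypothetical dense reference: tril(Us @ Vs.T, -1) @ Y, where
    Us = exp(-c t) * U and Vs = exp(c t) * V (O(N^2), for testing only)."""
    Us = np.exp(-c[None, :] * t[:, None]) * U
    Vs = np.exp(c[None, :] * t[:, None]) * V
    return np.tril(Us @ Vs.T, -1) @ Y


def matmul_lower_jvp(primals, tangents):
    """The product-rule JVP from this issue, written for plain NumPy arrays."""
    t, c, U, V, Y = primals
    tp, cp, Up, Vp, Yp = tangents
    # Tangent of the exponent c_j * t_n
    dphase = c[None, :] * tp[:, None] + cp[None, :] * t[:, None]
    Ut = -dphase * U + Up
    Vt = dphase * V + Vp
    Z = matmul_lower(t, c, U, V, Y)
    Zdot = (
        matmul_lower(t, c, U, V, Yp)
        + matmul_lower(t, c, Ut, V, Y)
        + matmul_lower(t, c, U, Vt, Y)
    )
    return Z, Zdot


rng = np.random.default_rng(0)
N, J, K = 50, 4, 3
t = np.sort(rng.uniform(0, 10, N))
c = rng.uniform(0, 1, J)
U = rng.standard_normal((N, J))
V = rng.standard_normal((N, J))
Y = rng.standard_normal((N, K))
primals = (t, c, U, V, Y)
tangents = tuple(rng.standard_normal(x.shape) for x in primals)

Z, Zdot = matmul_lower_jvp(primals, tangents)

# Central finite difference along the same tangent direction
eps = 1e-6
plus = matmul_lower(*(x + eps * dx for x, dx in zip(primals, tangents)))
minus = matmul_lower(*(x - eps * dx for x, dx in zip(primals, tangents)))
Zdot_fd = (plus - minus) / (2 * eps)

print(np.max(np.abs(Zdot - Zdot_fd)))  # small: only finite-difference error
```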

dfm added the enhancement (New feature or request) label Nov 9, 2020
dfm (Member, Author) commented Dec 9, 2020

I figured out the transpose rules for multiplication. To make them work, the matmul needs to be generalized to take "propagators" on both the left (`cl`) and the right (`cr`). In that case, if

```python
Z = mml(t, cl, cr, U, V, Y)
```

then the cotangents of `U`, `V`, and `Y`, given the cotangent `bZ` of `Z`, are

```python
bU = mml(t, cr, cl, bZ, Y, V)
bV = mmu(t, cr, cl, bZ, Y, U)
bY = mmu(t, cl, cr, U, V, bZ)
```
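These rules can be read off by unrolling the recursion in `mml` into index notation. The following is a sketch of that derivation, reconstructed from the `mml`/`mmu` code in this thread; the shorthand $\ell_j$, $r_k$, $c^{(l)}$ (for `cl`), and $c^{(r)}$ (for `cr`) is introduced here for illustration. For $m < n$,

$$
Z_{nk} = \sum_{m<n} \sum_{j} U_{nj}\,\ell_j(m,n)\,V_{mj}\,Y_{mk}\,r_k(m,n),
\qquad \ell_j(m,n) = e^{c^{(l)}_j (t_m - t_n)},\quad r_k(m,n) = e^{c^{(r)}_k (t_m - t_n)}.
$$

Since $Z$ is linear in each of $U$, $V$, and $Y$ separately, each cotangent contracts $\bar Z$ against the remaining factors:

$$
\begin{aligned}
\bar U_{nj} &= \sum_{m<n} \sum_k \bar Z_{nk}\, r_k(m,n)\, Y_{mk}\, V_{mj}\, \ell_j(m,n), \\
\bar V_{mj} &= \sum_{n>m} \sum_k Y_{mk}\, r_k(m,n)\, \bar Z_{nk}\, U_{nj}\, \ell_j(m,n), \\
\bar Y_{mk} &= \sum_{n>m} \sum_j V_{mj}\, \ell_j(m,n)\, U_{nj}\, \bar Z_{nk}\, r_k(m,n).
\end{aligned}
$$

The first line has the lower-triangular form of $Z$ under $(U, V, Y, c^{(l)}, c^{(r)}) \to (\bar Z, Y, V, c^{(r)}, c^{(l)})$, which is `mml(t, cr, cl, bZ, Y, V)`. The other two sum over $n > m$, which is what `mmu` computes: `mmu(t, cl, cr, U, V, Y)[n, k]` $= \sum_{m>n}\sum_j V_{nj}\,e^{c^{(l)}_j (t_n - t_m)}\,U_{mj}\,Y_{mk}\,e^{c^{(r)}_k (t_n - t_m)}$. Matching terms gives `bV = mmu(t, cr, cl, bZ, Y, U)` and `bY = mmu(t, cl, cr, U, V, bZ)`.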
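```python
import numpy as np


def mml(t, cl, cr, U, V, Y):
    # Strictly lower-triangular product: Z[n] = U[n] @ F_n, where F_n accumulates
    # outer(V[m], Y[m]) for m < n, propagated by cl on the left and cr on the right
    Z = np.empty_like(Y)
    Z[0] = 0.0
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(1, U.shape[0]):
        F += np.outer(V[n - 1], Y[n - 1])

        pl = np.exp(cl * (t[n - 1] - t[n]))
        pr = np.exp(cr * (t[n - 1] - t[n]))
        F = pl[:, None] * F * pr[None, :]  # same as np.diag(pl) @ F @ np.diag(pr)

        Z[n] = U[n] @ F
    return Z


def mmu(t, cl, cr, U, V, Y):
    # Strictly upper-triangular counterpart: sweeps backwards, accumulating
    # outer(U[m], Y[m]) for m > n
    Z = np.empty_like(Y)
    Z[-1] = 0.0
    F = np.zeros((U.shape[1], Y.shape[1]))
    for n in range(U.shape[0] - 2, -1, -1):
        F += np.outer(U[n + 1], Y[n + 1])

        pl = np.exp(cl * (t[n] - t[n + 1]))
        pr = np.exp(cr * (t[n] - t[n + 1]))
        F = pl[:, None] * F * pr[None, :]  # same as np.diag(pl) @ F @ np.diag(pr)

        Z[n] = V[n] @ F
    return Z


N = 100
J = 4
K = 3

np.random.seed(0)  # for reproducibility
t = np.sort(np.random.uniform(0, 10, N))
cl = np.random.rand(J)
cr = np.zeros(K)

U = np.random.randn(N, J)
V = np.random.randn(N, J)
Y = np.random.randn(N, K)

# Dense equivalents (valid here because cr = 0)
Up = np.exp(-cl[None, :] * t[:, None]) * U
Vp = np.exp(cl[None, :] * t[:, None]) * V

assert np.allclose(mml(t, cl, cr, U, V, Y), np.tril(Up @ Vp.T, -1) @ Y)
assert np.allclose(mmu(t, cl, cr, U, V, Y), np.triu(Vp @ Up.T, 1) @ Y)

bZ = np.random.randn(N, K)  # any cotangent for Z works; random for illustration

# Cotangents (vector-Jacobian products) of Z with respect to Y, U, and V
bY = mmu(t, cl, cr, U, V, bZ)
bU = mml(t, cr, cl, bZ, Y, V)
bV = mmu(t, cr, cl, bZ, Y, U)
```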
