Trying to implement gradient for SVD of low-rank matrix #2329

Closed
Jakob-Unfried opened this issue Feb 28, 2020 · 2 comments

Comments

@Jakob-Unfried
Contributor

I am trying to implement a function that performs a "reduced" SVD: for an (m, n) matrix of rank k < min(m, n), only the first k singular values are nonzero, so only the first k columns of U and the first k rows of Vh are relevant.

The code for the forward pass is:

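# Imports assumed from context (JAX API as of early 2020)
import jax.numpy as np
from jax import custom_transforms, defvjp
from jax.lax import dynamic_slice_in_dim

@custom_transforms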
def svd_reduced(X: array):
    assert len(X.shape) == 2

    U, S, Vh = np.linalg.svd(X, full_matrices=False)
    k = np.sum(S > 0.).astype(int)
    U = dynamic_slice_in_dim(U, 0, k, axis=1)  # U[:,:k]
    S = dynamic_slice_in_dim(S, 0, k, axis=0)  # S[:k]
    Vh = dynamic_slice_in_dim(Vh, 0, k, axis=0)  # Vh[:k,:]

    return U, S, Vh

I worked out the VJP formula and implemented it (not pasting it all, it's rather long):
defvjp(svd_reduced, _vjp_rule_svd)

However, I run into a
TypeError: Abstract value passed to int, which requires a concrete value. Try using value.astype(int) instead.
This is because k, the rank of the matrix X, depends on the values of X and is therefore not available during tracing.

I realise this means I will not be able to jit this function, but IIUC it should in principle be possible to calculate its grad, and that is all I want.
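
To illustrate the point about the rank being value-dependent, here is a minimal sketch of the forward pass alone, run eagerly on concrete arrays. The name `svd_reduced_eager` and the relative cutoff `rtol` are assumptions for illustration and not part of the original code; a relative cutoff is used instead of `S > 0.` because floating-point rounding usually leaves the "zero" singular values tiny but nonzero. It defines no gradient rule.

```python
import numpy as onp
import jax.numpy as jnp


def svd_reduced_eager(X, rtol=1e-5):
    """Reduced SVD with a value-dependent rank; only usable on concrete arrays."""
    U, S, Vh = jnp.linalg.svd(X, full_matrices=False)
    # Copy S to host NumPy so the rank becomes a concrete Python int.
    # Under jit (or any tracing) this conversion fails, because S then
    # has no concrete value to inspect.
    S_host = onp.asarray(S)
    k = int(onp.sum(S_host > rtol * S_host[0]))
    # Ordinary static slicing works because k is a plain Python int here.
    return U[:, :k], S[:k], Vh[:k, :]
```

Called on a rank-1 (3, 4) matrix, this should return factors of shapes (3, 1), (1,) and (1, 4). Wrapping it in `jax.jit` is expected to fail for the reason described above.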

How would I go about defining this svd_reduced and its VJP rule?

@shoyer
Collaborator

shoyer commented Feb 28, 2020

You can't currently use Python control flow inside custom_transforms, but this should be possible after the rewrite of custom transforms lands: #2026
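
As a rough sketch of what this could look like, assuming the rewrite corresponds to what is now exposed as `jax.custom_vjp`: the example below is deliberately simplified and is not the full VJP from the question. It returns only the nonzero singular values, whose cotangent rule is short (`U_k diag(g) Vh_k`, valid when the retained singular values are distinct). The names `singular_values_reduced`, `_truncated_svd` and `RTOL` are hypothetical. The value-dependent `int(...)` is evaluated on concrete arrays, so this is expected to work under `jax.grad` but not under `jax.jit`.

```python
import jax
import jax.numpy as jnp

RTOL = 1e-5  # hypothetical relative cutoff for "nonzero" singular values


def _truncated_svd(X):
    U, S, Vh = jnp.linalg.svd(X, full_matrices=False)
    # Value-dependent Python int: fine on concrete arrays, fails under jit.
    k = int(jnp.sum(S > RTOL * S[0]))
    return U[:, :k], S[:k], Vh[:k, :]


@jax.custom_vjp
def singular_values_reduced(X):
    """Return only the k nonzero singular values of X."""
    return _truncated_svd(X)[1]


def _fwd(X):
    U, S, Vh = _truncated_svd(X)
    # Save the truncated singular vectors as residuals for the backward pass.
    return S, (U, Vh)


def _bwd(res, g):
    U, Vh = res
    # dS_i = u_i^T dX v_i, so the cotangent of X is U_k diag(g) Vh_k.
    return ((U * g) @ Vh,)


singular_values_reduced.defvjp(_fwd, _bwd)
```

For example, `jax.grad(lambda X: jnp.sum(singular_values_reduced(X)))(X)` on a rank-1 matrix `X = a b^T` should give the outer product of the normalised `a` and `b`. The cotangents for `U` and `Vh` would need the longer formula mentioned in the question and are left out here.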

@Jakob-Unfried
Contributor Author

That looks like it will be awesome! Thanks
