I am trying to implement a function that performs a "reduced" SVD, i.e. for an `(m, n)` matrix of rank `k < min(m, n)`, only the first `k` singular values are nonzero and thus only the first `k` columns of `U` and the first `k` rows of `Vh` are relevant.
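A minimal sketch of the forward pass, assuming the rank `k` is already available as a plain Python `int`. `reduced_factors` is a hypothetical name, not the original `svd_reduced`:

```python
import jax.numpy as jnp
import numpy as np


def reduced_factors(x, k):
    """Keep only the first k singular triplets of x (k must be a Python int)."""
    u, s, vh = jnp.linalg.svd(x, full_matrices=False)
    return u[:, :k], s[:k], vh[:k, :]


# Build a 6x5 matrix of rank 3 with known, well-separated singular values.
rng = np.random.default_rng(0)
q1, _ = np.linalg.qr(rng.standard_normal((6, 3)))
q2, _ = np.linalg.qr(rng.standard_normal((5, 3)))
x = jnp.asarray(q1 @ np.diag([3.0, 2.0, 1.0]) @ q2.T)

u_k, s_k, vh_k = reduced_factors(x, 3)
x_rebuilt = (u_k * s_k) @ vh_k  # same as u_k @ diag(s_k) @ vh_k
print(s_k)                              # close to [3, 2, 1]
print(jnp.max(jnp.abs(x_rebuilt - x)))  # float32 round-off only
```

The three truncated factors are enough to rebuild `X` exactly (up to round-off), which is why the remaining columns and rows can be dropped.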
I worked out the VJP formula and implemented it (not pasting it all, it's rather long), registering it with `defvjp(svd_reduced, _vjp_rule_svd)`.
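For reference, one standard form of the backward rule for the rank-`k` factors (not necessarily identical to the formula mentioned above), assuming the `k` nonzero singular values are distinct. Here $X = U S V^\top$ with $V = $ `Vh`$^\top$, $S = \operatorname{diag}(s_1,\dots,s_k)$, and a bar denotes a cotangent, with $\bar S = \operatorname{diag}(\bar s)$:

```math
\begin{aligned}
\bar X ={}& U\Big[\big(F\circ(U^\top\bar U-\bar U^\top U)\big)S+\bar S+S\big(F\circ(V^\top\bar V-\bar V^\top V)\big)\Big]V^\top\\
&+(I_m-UU^\top)\,\bar U S^{-1}V^\top+U S^{-1}\,\bar V^\top(I_n-VV^\top),\\
F_{ij}={}&\begin{cases}1/(s_j^2-s_i^2), & i\neq j,\\ 0, & i=j.\end{cases}
\end{aligned}
```

Only the `k` retained singular values appear in $F$ and $S^{-1}$, so nothing divides by the zero singular values; the components of $\bar U$ and $\bar V$ outside the column spaces of $U$ and $V$ enter only through the two $S^{-1}$ terms.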
However, I run into a `TypeError`: `Abstract value passed to int, which requires a concrete value. Try using value.astype(int) instead.`
This is because `k`, the rank of the matrix `X`, depends on the values of `X` and is therefore not available during tracing.
I realise this means I will not be able to `jit` this function, but IIUC it should in principle still be possible to compute its `grad`, and that is all I want.
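A small sketch of the tracing issue, assuming it is acceptable to compute the rank eagerly with NumPy before any JAX transformation is applied. `numerical_rank`, `top_k_sum_traced` and `top_k_sum` are hypothetical names, and `rtol=1e-5` is an assumed tolerance. Once `k` is a static Python `int`, the slice `s[:k]` has a shape known at trace time, so `grad` and even `jit` go through:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def numerical_rank(x, rtol=1e-5):
    """Hypothetical helper: rank of x computed eagerly with NumPy (a Python int)."""
    s = np.linalg.svd(np.asarray(x), compute_uv=False)
    return int(np.sum(s > rtol * s[0]))


def top_k_sum_traced(x):
    # Fails under jit: k depends on the values of x, which are abstract here.
    s = jnp.linalg.svd(x, compute_uv=False)
    k = int(jnp.sum(s > 1e-5 * s[0]))
    return jnp.sum(s[:k])


@partial(jax.jit, static_argnums=1)
def top_k_sum(x, k):
    # k is a static Python int, so s[:k] has a shape known at trace time.
    s = jnp.linalg.svd(x, compute_uv=False)
    return jnp.sum(s[:k])


rng = np.random.default_rng(0)
q1, _ = np.linalg.qr(rng.standard_normal((6, 3)))
q2, _ = np.linalg.qr(rng.standard_normal((5, 3)))
x = q1 @ np.diag([3.0, 2.0, 1.0]) @ q2.T  # rank 3, float64 NumPy array

try:
    jax.jit(top_k_sum_traced)(x)
except (jax.errors.ConcretizationTypeError,
        jax.errors.TracerIntegerConversionError) as err:
    # int() on a traced value cannot be turned into a concrete number.
    print("traced rank fails:", type(err).__name__)

k = numerical_rank(x)  # computed outside any JAX transformation
print(k, top_k_sum(jnp.asarray(x), k))
print(jax.grad(top_k_sum)(jnp.asarray(x), k))  # gradient w.r.t. x only
```

With `static_argnums=1`, `jax.jit` retraces once per distinct value of `k`, which is cheap if the rank rarely changes.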
How would I go about defining this `reduced_svd` and its `vjp_rule`?
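One possible way to define `reduced_svd` with a custom VJP, sketched with the current `jax.custom_vjp` API rather than the `defvjp(...)` call above, and assuming `k` is computed eagerly (for example with NumPy) and passed in as a concrete Python `int`. Declaring `k` in `nondiff_argnums` keeps it out of differentiation, and the backward pass implements the standard SVD backward formula restricted to the `k` retained triplets, assuming their singular values are distinct and nonzero. `_reduced_svd_fwd` and `_reduced_svd_bwd` are hypothetical names:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def reduced_svd(x, k):
    """First k singular triplets of x; k must be a concrete Python int."""
    u, s, vh = jnp.linalg.svd(x, full_matrices=False)
    return u[:, :k], s[:k], vh[:k, :]


def _reduced_svd_fwd(x, k):
    out = reduced_svd(x, k)
    return out, out  # the factors themselves are the residuals


def _reduced_svd_bwd(k, res, g):
    # Nondiff args come first in the bwd signature; k equals s.shape[0] here.
    u, s, vh = res
    gu, gs, gvh = g
    eye = jnp.eye(k, dtype=s.dtype)
    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on the diagonal.
    diff = s[None, :] ** 2 - s[:, None] ** 2
    f = (1.0 - eye) / (diff + eye)
    ut_gu = u.T @ gu    # U^T Ubar (k x k)
    vt_gv = vh @ gvh.T  # V^T Vbar (k x k)
    inner = (
        (f * (ut_gu - ut_gu.T)) * s[None, :]    # (F o (U^T Ubar - Ubar^T U)) S
        + jnp.diag(gs)
        + s[:, None] * (f * (vt_gv - vt_gv.T))  # S (F o (V^T Vbar - Vbar^T V))
    )
    gx = u @ inner @ vh
    # Parts of Ubar / Vbar outside span(U) / span(V), scaled by 1/s.
    gx = gx + ((gu - u @ ut_gu) / s[None, :]) @ vh
    gx = gx + u @ ((gvh - vt_gv.T @ vh) / s[:, None])
    return (gx,)


reduced_svd.defvjp(_reduced_svd_fwd, _reduced_svd_bwd)

# Usage on a 6x5 matrix of rank 3.
rng = np.random.default_rng(0)
q1, _ = np.linalg.qr(rng.standard_normal((6, 3)))
q2, _ = np.linalg.qr(rng.standard_normal((5, 3)))
x = jnp.asarray(q1 @ np.diag([3.0, 2.0, 1.0]) @ q2.T)
k = 3  # in practice computed eagerly, outside jax.grad / jax.jit

# Gradient of the sum of the top-k singular values; expected U V^T.
g_nuc = jax.grad(lambda x: jnp.sum(reduced_svd(x, k)[1]))(x)

# Gradient of <U S Vh, W>, which also exercises the U and Vh cotangents;
# expected W - (I - U U^T) W (I - V V^T) at a rank-k point.
w = jnp.asarray(rng.standard_normal((6, 5)))


def recon_loss(x):
    u, s, vh = reduced_svd(x, k)
    return jnp.sum(((u * s) @ vh) * w)


g_rec = jax.grad(recon_loss)(x)
```

Since `k` is static here, wrapping a loss that calls `reduced_svd` in `jax.jit` with `k` marked as a static argument should also work, provided `k` itself is computed outside the trace.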