WIP matmul on arrays of arrays #325
In general we really don't consistently enforce that cotangents of reals are real if they interact with a complex number, i.e. in most places we do the Zygote thing. There are a few places where I've explicitly added projections to real for new or improved rules, following JuliaDiff/ChainRulesCore.jl#176, e.g. the rules for If we choose to go this way, which I think we should at least for real/complex, then it should probably be handled in its own PR. I haven't looked at this PR in detail yet, but I think the easy thing to do right now is just test all real and all complex.
OK, I can drop that test for now.
All conjugation looks correct. I made some minor suggestions, and it looks like a few tests are missing. Once tests pass and coverage is good, feel free to merge.
    B::AbstractMatrix{<:MatMulField},
)
    function times_pullback(Ȳ)
        @assert size(B, 1) === 1  # otherwise primal would have failed.
Is this necessary? Because the primal already did this check.
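For context on the question above, here is a minimal sketch of the primal behaviour the assertion refers to. It uses only Base Julia; the variable names are illustrative and not taken from the PR. A vector times a matrix is treated as a column times a matrix, so it only succeeds when the matrix has exactly one row.

```julia
# A vector times a 1-row matrix is an outer product.
v = [1.0, 2.0, 3.0]
row = [4.0 5.0]              # 1×2 matrix
Y = v * row                  # 3×2 matrix

# With more than one row the primal itself throws, so a pullback
# would never be reached for such inputs.
two_rows = [4.0 5.0; 6.0 7.0]
threw = try
    v * two_rows
    false
catch err
    err isa DimensionMismatch
end
```

If that holds, the `@assert` inside the pullback documents an invariant rather than guarding against a reachable case.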
function rrule(
-    ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
+    ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:MatMulField}
Would it be notationally clearer to make `A` lower case since it's a scalar?
)
    function times_pullback(Ȳ)
        return (
            NO_FIELDS,
-            @thunk(dot(Ȳ, B)'),
+            (dot(Ȳ, B)'),
Swapping the order in `dot` conjugates the output, e.g.
julia> x, y = randn(ComplexF64, 3), randn(ComplexF64, 3);
julia> dot(x, y) == dot(y, x)'
true
Could make the same change in the below rule.
-            (dot(Ȳ, B)'),
+            dot(B, Ȳ),
    A::AbstractVecOrMat{<:MatMulField},
    B::AbstractVecOrMat{<:MatMulField},
Is there a test for this rule?
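As a rough illustration of what a test for this rule could check, the sketch below uses a hypothetical stand-in, `matmul_with_pullback`, rather than the PR's actual `rrule`. It assumes the standard matrix-product cotangents `Ȳ * B'` and `A' * Ȳ`, and verifies them against the inner-product (adjoint) identity on fixed complex inputs.

```julia
using LinearAlgebra

# Hypothetical stand-in for the rule under review, not the PR's code.
function matmul_with_pullback(A, B)
    Y = A * B
    pullback(Ȳ) = (Ȳ * B', A' * Ȳ)   # cotangents for A and B
    return Y, pullback
end

A  = ComplexF64[1.0+1.0im 2.0; 0.0 1.0-1.0im]
B  = ComplexF64[0.5 1.0im; 2.0 1.0]
Ȳ  = ComplexF64[1.0 -1.0im; 0.5+0.5im 2.0]
ΔA = ComplexF64[0.1 0.2im; -0.3 0.4]
ΔB = ComplexF64[1.0im 0.0; 0.25 -0.5]

Y, pb = matmul_with_pullback(A, B)
Ā, B̄ = pb(Ȳ)

# Adjoint identity: <Ȳ, ΔA * B> == <Ā, ΔA> and <Ȳ, A * ΔB> == <B̄, ΔB>.
ok_A = dot(Ȳ, ΔA * B) ≈ dot(Ā, ΔA)
ok_B = dot(Ȳ, A * ΔB) ≈ dot(B̄, ΔB)
```

A wrong conjugation in either cotangent would break these identities for complex inputs, which is why all-complex test cases are useful here.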
function rrule(
    ::typeof(*),
    A::AbstractVector{<:MatMulField},
    B::AbstractMatrix{<:MatMulField},
)
Is there a test for this rule?
Closes FluxML/Zygote.jl#837
Also needed to add a special case for not inplacing, as one is broken:
https://github.com/JuliaLang/julia/issues/38772
Also the CoVector one seems to not work.
But I ran into issues. It seems we were not testing `*(::Real, ::Array{Complex})`, which gives an incorrect answer for the first argument. At least I assume it is incorrect, since it is a complex number. It would be super useful if someone who really understands the complex number stuff (@sethaxen, @simeonschaub) is able to take over this PR, at least as far as getting the existing complex number tests passing.
(moved to separate issue)
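To make the `*(::Real, ::Array{Complex})` concern concrete, here is a small sketch. It assumes the scalar cotangent is `dot(B, Ȳ)`, as in the review suggestion; `scalar_cotangent` and `project_to_primal` are hypothetical helpers written for this example and are not ChainRules API. With a real scalar and a complex array the raw cotangent is complex, and taking its real part matches a finite-difference derivative.

```julia
using LinearAlgebra

# Hypothetical helper: cotangent of the scalar `a` in `Y = a * B`.
scalar_cotangent(Ȳ, B) = dot(B, Ȳ)

# Hypothetical projection: keep only the real part when the primal is real.
project_to_primal(a::Real, ā) = real(ā)
project_to_primal(a::Number, ā) = ā

a = 2.0
B = [1.0 + 2.0im, 3.0 + 1.0im]
Ȳ = [0.5 - 1.0im, 1.0 + 1.0im]

ā = scalar_cotangent(Ȳ, B)        # complex, even though `a` is real
ā_real = project_to_primal(a, ā)

# Finite-difference check on the real-valued loss real(<Ȳ, a * B>).
loss(t) = real(dot(Ȳ, t * B))
h = 1e-6
fd = (loss(a + h) - loss(a - h)) / (2h)
```

Whether rules should apply such a projection is the open question discussed in JuliaDiff/ChainRulesCore.jl#176; this only shows what the projection would do for this one case.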