WIP matmul on arrays of arrays #325

Member

@oxinabox oxinabox commented Dec 8, 2020

Closes FluxML/Zygote.jl#837

I also needed to add a special case for not operating in place, as one is broken:
https://github.com/JuliaLang/julia/issues/38772

Also, the CoVector one seems not to work.

But I ran into issues. It seems we were not testing `*(::Real, ::Array{Complex})`, which gives an incorrect answer for the first argument. At least I assume it is incorrect, since it is a complex number.

It would be super useful if someone who really understands the complex number stuff (@sethaxen, @simeonschaub) were able to take over this PR, at least as far as getting the existing complex number tests passing.

(moved to separate issue)

@sethaxen
Member

In general we really don't consistently enforce that cotangents of reals are real if they interact with a complex number, i.e. in most places we do the Zygote thing. There are a few places where I've explicitly added projections to real for new or improved rules, following JuliaDiff/ChainRulesCore.jl#176, e.g. the rules for abs (or any scalar rule that calls _realconjtimes), the cotangents of the eigenvalues of Hermitian matrices, etc.

If we choose to go this way, which I think we should at least for real/complex, then it should probably be handled in its own PR. I haven't looked at this PR in detail yet, but I think the easy thing to do right now is just test all real and all complex.
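To make the real/complex question concrete, here is a minimal sketch using only `LinearAlgebra` (no ChainRules API). The names `a`, `B`, `Ȳ`, `ā_unprojected`, `ā_projected` and `loss` are illustrative assumptions, not identifiers from this PR. For a real scalar `a` and a complex array `B`, the unprojected cotangent `dot(B, Ȳ)` is complex in general, while its real part matches a finite-difference derivative of a real-valued loss taken along the real direction.

```julia
using LinearAlgebra

# Hypothetical inputs: a real scalar times a complex vector.
a = 2.0
B = [1.0 + 2.0im, 3.0 - 1.0im]
Ȳ = [0.5 - 1.0im, 2.0 + 0.25im]   # incoming cotangent for Y = a * B

# Unprojected cotangent for `a`: complex in general.
ā_unprojected = dot(B, Ȳ)

# Projected onto the reals, since `a` itself is real.
ā_projected = real(ā_unprojected)

# A real-valued loss that is linear in Y = a * B and pairs it with Ȳ.
loss(a) = real(dot(Ȳ, a * B))

# Central finite difference along the real direction.
h = 1e-6
fd = (loss(a + h) - loss(a - h)) / (2h)

println((ā_unprojected, ā_projected, fd))
```

Under these assumptions only the projected value agrees with the finite difference, which is the behaviour a projection to real would enforce.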

@oxinabox
Member Author

OK, I can drop that test for now.
Can you open an issue about what you said regarding enforcing that?

@oxinabox oxinabox changed the title from "WIP matmul on arrays of arrays and fix `Real*Array{<:Complex}`" to "WIP matmul on arrays of arrays" on Dec 14, 2020
Member

@sethaxen sethaxen left a comment

All conjugation looks correct. I made some minor suggestions, and it looks like a few tests are missing. Once tests pass and coverage is good, feel free to merge.

    B::AbstractMatrix{<:MatMulField},
)
    function times_pullback(Ȳ)
        @assert size(B, 1) === 1  # otherwise primal would have failed.
Member

Is this necessary? Because the primal already did this check.
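
For context on the assertion being questioned, here is a small sketch of the primal behaviour. In Julia's `LinearAlgebra`, `vector * matrix` treats the vector as an n×1 column, so the product only succeeds when the matrix has exactly one row. The variable names are illustrative and not taken from the PR.

```julia
using LinearAlgebra

v = [1.0, 2.0, 3.0]

# A 1×2 matrix: the product is a 3×2 matrix (column times row).
B_ok = [4.0 5.0]
Y = v * B_ok

# A 2×2 matrix: the primal itself throws, so a pullback is never reached.
B_bad = ones(2, 2)
primal_failed = try
    v * B_bad
    false
catch err
    err isa DimensionMismatch
end

println(size(Y), " ", primal_failed)
```

If the primal always throws for `size(B, 1) != 1`, the assertion inside the pullback can only document the invariant. It cannot catch a new failure.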

 function rrule(
-    ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
+    ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:MatMulField}
Member

Would it be notationally clearer to make A lower case since it's a scalar?

 )
     function times_pullback(Ȳ)
         return (
             NO_FIELDS,
-            @thunk(dot(Ȳ, B)'),
+            (dot(Ȳ, B)'),
Member

Swapping the order in dot conjugates the output, e.g.

julia> x, y = randn(ComplexF64, 3), randn(ComplexF64, 3);

julia> dot(x, y) == dot(y, x)'
true

The same change could be made in the rule below.

Suggested change
-            (dot(Ȳ, B)'),
+            dot(B, Ȳ),
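
As a quick check that the suggestion is a pure rewrite, the sketch below compares both forms on complex matrices rather than vectors, since `dot` accepts arrays of any matching shape. The values of `B` and `Ȳ` are made up for illustration.

```julia
using LinearAlgebra

# Hypothetical complex inputs with the same shape.
B = [1.0 + 2.0im  0.0 - 1.0im; 3.0 + 0.0im  2.0 + 2.0im]
Ȳ = [0.5 - 1.0im  1.0 + 1.0im; 2.0 + 0.0im  0.0 + 3.0im]

original  = dot(Ȳ, B)'   # form currently in the diff
suggested = dot(B, Ȳ)    # form proposed in the suggestion

println(original ≈ suggested)
```

The two agree because `dot` conjugates its first argument, so swapping the arguments conjugates the result, and the suggested form skips the extra adjoint.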

Comment on lines +32 to +33
A::AbstractVecOrMat{<:MatMulField},
B::AbstractVecOrMat{<:MatMulField},
Member

Is there a test for this rule?

Comment on lines +79 to +83
function rrule(
    ::typeof(*),
    A::AbstractVector{<:MatMulField},
    B::AbstractMatrix{<:MatMulField},
)
Member

Is there a test for this rule?

Successfully merging this pull request may close these issues.

StaticArray block vec' * mat segfault on 0.5.11
2 participants