Skip to content

Commit

Permalink
Add support for dealiasing Adjoint and Transpose
Browse files Browse the repository at this point in the history
And fixup the assumption that `broadcast!(f, C, A)` is safe
  • Loading branch information
mbauman committed Feb 6, 2018
1 parent b2da7b4 commit 2301129
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
5 changes: 4 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,13 @@ end
return C
end

# In the one-argument case, we don't need to worry about aliasing as we only make one pass
# In the one-argument case, we can avoid de-aliasing `A` from `C` if
# `A === C`. Otherwise `A` might be something like `transpose(C)` or
# another such re-ordering that won't iterate the two safely.
@inline function _broadcast!(f, C, A)
shape = broadcast_indices(C)
@boundscheck check_broadcast_indices(shape, A)
A !== C && (A = unalias(C, A))
keeps, Idefaults = map_newindexer(shape, A, ())
iter = CartesianIndices(shape)
_broadcast!(f, C, keeps, Idefaults, A, (), Val(0), iter)
Expand Down
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ end
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)

using Base: unalias, mightalias, dataids
Base.unalias(dest, A::Union{Adjoint,Transpose}) = mightalias(dest, A) ? typeof(A)(unalias(dest, A.parent)) : A
Base.dataids(A::Union{Adjoint,Transpose}) = dataids(A.parent)

# wrapping lowercase quasi-constructors
"""
adjoint(A)
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,17 @@ end
@test adjoint!(b, a) === b
end

@testset "aliasing with adjoint and transpose" begin
A = collect(reshape(1:25, 5, 5)) .+ rand.().*im
B = copy(A)
B .= B'
@test B == A'
B = copy(A)
B .= transpose(B)
@test B == transpose(A)
B = copy(A)
B .= B .* B'
@test B == A .* A'
end

end # module TestAdjointTranspose

0 comments on commit 2301129

Please sign in to comment.