Skip to content

Commit

Permalink
Overload LinearAlgebra.tr (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Jul 6, 2023
1 parent b6b1044 commit 65eb422
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LinearMaps"
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
version = "3.10.2"
version = "3.11.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/history.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Version history

## What's new in v3.11

* The `tr` function from `LinearAlgebra.jl` is now overloaded both for generic `LinearMap`
types and specialized for most provided `LinearMap` types. In the generic case, this is
computationally as expensive as computing the whole matrix representation, though the
latter is, of course, not stored.

## What's new in v3.10

* A new `MulStyle` trait called `TwoArg` has been added. It should be used for `LinearMap`s
Expand Down
13 changes: 13 additions & 0 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ as in the usual matrix case: `transpose(A) * x` and `mul!(y, A', x)`, for instan
a linear map for which you only have a function definition (e.g. to be able
to use its `transpose` or `adjoint`).

!!! note
In Julia versions v1.9 and higher, conversion to sparse matrices requires loading
`SparseArrays.jl` by the user in advance.

### Slicing methods

Complete slicing, i.e., `A[:,j]`, `A[:,J]`, `A[i,:]`, `A[I,:]` and `A[:,:]` for `i`, `j`
Expand All @@ -188,3 +192,12 @@ slicing) to standard unit vectors of appropriate length. By complete slicing we
two-dimensional Cartesian indexing where at least one of the "indices" is a colon. This is
facilitated by overloads of `Base.getindex`. Partial slicing à la `A[I,J]` and scalar or
linear indexing are _not_ supported.

### Sum, product, mean and trace

Natural function overloads for `Base.sum`, `Base.prod`, `Statistics.mean` and `LinearAlgebra.tr`
exist.

!!! note
In Julia versions v1.9 and higher, creating the mean linear operator requires loading
`Statistics.jl` by the user in advance.
3 changes: 2 additions & 1 deletion src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export ⊗, squarekron, kronsum, ⊕, sumkronsum, khatrirao, facesplitting

using LinearAlgebra
using LinearAlgebra: AbstractQ
import LinearAlgebra: mul!
import LinearAlgebra: mul!, tr

using Base: require_one_based_indexing

Expand Down Expand Up @@ -348,6 +348,7 @@ include("conversion.jl") # conversion of linear maps to matrices
include("show.jl") # show methods for LinearMap objects
include("getindex.jl") # getindex functionality
include("inversemap.jl")
include("trace.jl")

"""
LinearMap(A::LinearMap; kwargs...)::WrappedMap
Expand Down
4 changes: 2 additions & 2 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ compared to [`kron`](@ref), but benchmarking intended use cases is highly recomm
function squarekron(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
maps = (A, B, C, Ds...)
T = promote_type(map(eltype, maps)...)
all(_issquare, maps) || throw(ArgumentError("operators need to be square in Kronecker sums"))
all(_issquare, maps) || throw(ArgumentError("operators need to be square in squarekron"))
ns = map(a -> size(a, 1), maps)
firstmap = first(maps) UniformScalingMap(true, prod(ns[2:end]))
lastmap = UniformScalingMap(true, prod(ns[1:end-1])) last(maps)
Expand Down Expand Up @@ -376,7 +376,7 @@ true
[^1]: Fernandes, P. and Plateau, B. and Stewart, W. J. ["Efficient Descriptor-Vector Multiplications in Stochastic Automata Networks"](https://doi.org/10.1145/278298.278303), _Journal of the ACM_, 45(3), 381–414, 1998.
"""
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix)
LinearAlgebra.checksquare(A, B)
(_issquare(A) && _issquare(B)) || throw(ArgumentError("operators need to be square in Kronecker sums"))
A UniformScalingMap(true, size(B,1)) + UniformScalingMap(true, size(A,1)) B
end
function sumkronsum(A::MapOrMatrix, B::MapOrMatrix, C::MapOrMatrix, Ds::MapOrMatrix...)
Expand Down
62 changes: 62 additions & 0 deletions src/trace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
function tr(A::LinearMap)
_issquare(A) || throw(ArgumentError("operator needs to be square in tr"))
_tr(A)
end

function _tr(A::LinearMap{T}) where {T}
S = typeof(oneunit(eltype(A)) + oneunit(eltype(A)))
ax1, ax2 = axes(A)
xi = zeros(eltype(A), ax2)
y = similar(xi, T, ax1)
o = one(T)
z = zero(T)
s = zero(S)
@inbounds for (i, j) in zip(ax1, ax2)
xi[j] = o
mul!(y, A, xi)
xi[j] = z
s += y[i]
end
return s
end
function _tr(A::OOPFunctionMap{T}) where {T}
S = typeof(oneunit(eltype(A)) + oneunit(eltype(A)))
ax1, ax2 = axes(A)
xi = zeros(eltype(A), ax2)
o = one(T)
z = zero(T)
s = zero(S)
@inbounds for (i, j) in zip(ax1, ax2)
xi[j] = o
s += (A * xi)[i]
xi[j] = z
end
return s
end
# specialiations
_tr(A::AbstractVecOrMat) = tr(A)
_tr(A::WrappedMap) = _tr(A.lmap)
_tr(A::TransposeMap) = _tr(A.lmap)
_tr(A::AdjointMap) = conj(_tr(A.lmap))
_tr(A::UniformScalingMap) = A.M * A.λ
_tr(A::ScaledMap) = A.λ * _tr(A.lmap)
function _tr(L::KroneckerMap)
if all(_issquare, L.maps)
return prod(_tr, L.maps)
else
return invoke(_tr, Tuple{LinearMap}, L)
end
end
function _tr(L::OuterProductMap{<:RealOrComplex})
a, bt = L.maps
return bt.lmap*a.lmap
end
function _tr(L::OuterProductMap)
a, bt = L.maps
mapreduce(*, +, a.lmap, bt.lmap)
end
function _tr(L::KroneckerSumMap)
A, B = L.maps # A and B are square by construction
return _tr(A) * size(B, 1) + _tr(B) * size(A, 1)
end
_tr(A::FillMap) = A.size[1] * A.λ
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ include("inversemap.jl")
include("rrules.jl")

include("khatrirao.jl")

include("trace.jl")
28 changes: 28 additions & 0 deletions test/trace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using LinearMaps, LinearAlgebra, Test

@testset "trace" begin
for A in (randn(5, 5), randn(ComplexF64, 5, 5))
@test tr(LinearMap(A)) == tr(A)
@test tr(transpose(LinearMap(A))) == tr(A)
@test tr(adjoint(LinearMap(A))) == tr(A')
end
@test tr(LinearMap(3I, 10)) == 30
@test tr(LinearMap{Int}(cumsum, 10)) == 10
@test tr(LinearMap{Int}(cumsum, reversecumsumreverse, 10)') == 10
@test tr(LinearMap{Complex{Int}}(cumsum, reversecumsumreverse, 10)') == 10
@test tr(LinearMap{Int}(cumsum!, 10)) == 10
@test tr(2LinearMap{Int}(cumsum!, 10)) == 20
A = randn(3, 5); B = copy(transpose(A))
@test tr(A B) == tr(kron(A, B))
@test tr(A B A B) tr(kron(A, B, A, B))
A = randn(5, 5); B = copy(transpose(A))
@test tr(A B) tr(kron(A, B))
@test tr(A B A) tr(kron(A, B, A))
@test tr(A B A B) tr(kron(A, B, A, B))
v = A[:,1]
@test tr(v v') norm(v)^2
v = [randn(2,2) for _ in 1:3]
@test tr(v v') mapreduce(*, +, v, v')
@test tr(LinearMap{Int}(cumsum!, 10) LinearMap{Int}(cumsum!, 10)) == 200
@test tr(FillMap(true, 5, 5)) == 5
end

2 comments on commit 65eb422

@dkarrasch
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88218

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.11.0 -m "<description of version>" 65eb42238b3435d553d6cedf4a2fbb312b8bf4fc
git push origin v3.11.0

Please sign in to comment.