Skip to content

Commit

Permalink
Multiply BlockDiagonal with JuMP AffExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
fchorney committed May 11, 2022
1 parent 0b98861 commit 2122093
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 2 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
name = "BlockDiagonals"
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.26"
version = "0.1.27"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
FiniteDifferences = "0.12.3"
JuMP = "1"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "Documenter", "Random", "Test"]
test = ["ChainRulesTestUtils", "Dates", "Distributions", "Documenter", "Random", "Test"]
2 changes: 2 additions & 0 deletions src/BlockDiagonals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Base: @propagate_inbounds
using ChainRulesCore
using FillArrays: Zeros
using FiniteDifferences
using JuMP: AffExpr
using LinearAlgebra

import ChainRulesCore.ProjectTo
Expand All @@ -15,5 +16,6 @@ include("blockdiagonal.jl")
include("base_maths.jl")
include("chainrules.jl")
include("linalg.jl")
include("jump.jl")

end # end module
9 changes: 9 additions & 0 deletions src/jump.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
function Base.:*(A::BlockDiagonal, x::Vector{T}) where {T<:AffExpr}
Multiply a `BlockDiagonal` with a `Vector{AffExpr}` from JuMP so we don't need to convert
the `BlockDiagonal` to a `Matrix` first.
"""
function Base.:*(A::BlockDiagonal, x::Vector{T}) where {T<:AffExpr}
return mul!(similar(x, T, axes(A, 1)), A, x)
end
34 changes: 34 additions & 0 deletions test/jump.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
@testset "JuMP" begin
num_nodes = 2
num_targets = 2
nodes = [randstring(3) for _ in 1:num_nodes]
targets = [DateTime(2020, 1, 1, h) for h in 1:num_targets]

dists = Vector{MvNormal}()
for k in targets
mu = randn(num_nodes * num_targets)
X = rand(num_nodes * num_targets, num_nodes * num_targets)
sigma = X * X' + I
push!(dists, MvNormal(mu, sigma))
end

covs = [Matrix(cov(d)) for d in dists]
means = [mean(d) for d in dists]

preds = (mean=vcat(means...), cov=BlockDiagonal(covs), target=targets, nodes=nodes)

@testset "Multiplication" begin
model = JuMP.Model()

n = length(preds.mean)
v = (
supply_mwh=@variable(model, supply_mwh[1:n] >= 0),
demand_mwh=@variable(model, demand_mwh[1:n] <= 0),
)

volume = v.supply_mwh + v.demand_mwh
normalized_sqrt_cov = cholesky(preds.cov).U / 24

@test normalized_sqrt_cov * volume == Matrix(normalized_sqrt_cov) * volume
end
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
using BlockDiagonals
using ChainRulesCore
using ChainRulesTestUtils
using Dates
using Distributions
using Documenter
using FiniteDifferences # For overloading to_vec
using JuMP
using Random
using Test
using LinearAlgebra

Random.seed!(42069)

push!(ChainRulesTestUtils.TRANSFORMS_TO_ALT_TANGENTS, x -> @thunk(x))

@testset "BlockDiagonals" begin
Expand All @@ -15,4 +21,5 @@ push!(ChainRulesTestUtils.TRANSFORMS_TO_ALT_TANGENTS, x -> @thunk(x))
include("base_maths.jl")
include("chainrules.jl")
include("linalg.jl")
include("jump.jl")
end # tests

0 comments on commit 2122093

Please sign in to comment.