Skip to content

Commit

Permalink
Test sparse broadcast! over combinations of broadcast scalars and spa…
Browse files Browse the repository at this point in the history
…rse vectors/matrices.
  • Loading branch information
Sacha0 committed Jan 1, 2017
1 parent 8c60d82 commit 3d9fc3c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
18 changes: 9 additions & 9 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -867,22 +867,22 @@ end
# argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs
# in their orginal order, and such that the result of broadcast(g, passedargstup...) is
# broadcast(f, mixedargs...)
capturescalars(f, mixedargs) =
@inline capturescalars(f, mixedargs) =
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
# Recursion cases for capturescalars
capturescalars(f, passedargstup, scalararg, mixedargs...) =
@inline capturescalars(f, passedargstup, scalararg, mixedargs...) =
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
# Base cases for capturescalars
capturescalars(f, passedargstup, scalararg) =
@inline capturescalars(f, passedargstup, scalararg) =
(capturelastscalar(f, scalararg), passedargstup)
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
(passlastnonscalar(f), (passedargstup..., nonscalararg))
passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))

# NOTE: The following two method definitions work around #19096.
broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A)
Expand Down
36 changes: 33 additions & 3 deletions test/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ end
end
end


@testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
Expand All @@ -202,10 +203,10 @@ end
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
end

@testset "broadcast over combinations of scalars and sparse vectors/matrices" begin
N, M, p = 10, 12, 0.3
@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
N, M, p = 10, 12, 0.5
elT = Float64
s = elT(2.0)
s = Float32(2.0)
V = sprand(elT, N, p)
A = sprand(elT, N, M, p)
fV, fA = Array(V), Array(A)
Expand Down Expand Up @@ -235,8 +236,29 @@ end
((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)),
((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)),
((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), )
# test broadcast entry point
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
# test broadcast! entry point
fX = broadcast(*, sparseargs...); X = sparse(fX)
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
X = sparse(fX) # reset / warmup for @allocated test
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
# This test (and the analog below) fails for three reasons:
# (1) In all cases, generating the closures that capture the scalar arguments
# results in allocation, not sure why.
# (2) In some cases, though _broadcast_eltype (which wraps _return_type)
# consistently provides the correct result eltype when passed the closure
# that incorporates the scalar arguments to broadcast (and, with #19667,
# is inferable, so the overall return type from broadcast is inferred),
# in some cases inference seems unable to determine the return type of
# direct calls to that closure. This issue causes variables in both the
# broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and
# the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have
# inferred type Any, resulting in allocation and lackluster performance.
# (3) The sparseargs... splat in the call above allocates a bit, but of course
# that issue is negligible and perhaps could be accounted for in the test.
end
end
# test combinations at the limit of inference (eight arguments net)
Expand All @@ -248,8 +270,16 @@ end
((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices
((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices
((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices
# test broadcast entry point
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
# test broadcast! entry point
fX = broadcast(*, sparseargs...); X = sparse(fX)
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
X = sparse(fX) # reset / warmup for @allocated test
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
# please see the note a few lines above re. this @test_broken
end
end

Expand Down

0 comments on commit 3d9fc3c

Please sign in to comment.