From a85e3eb8f7560cb40035be0b452fea042b80c5e5 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 11:04:43 +0200 Subject: [PATCH 1/8] Add codecov badge --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 37c127e..0cf9231 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # SparseVariables.jl +[![codecov](https://codecov.io/gh/hellemo/SparseVariables.jl/branch/main/graph/badge.svg?token=2LXGVU04YS)](https://codecov.io/gh/hellemo/SparseVariables.jl) + This package contains routines for improved and easier handling of sparse data and sparse arrays of optimizaton variables in JuMP. From 18bae32854d126e15eb43af491faa214dc753793 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 13:14:18 +0200 Subject: [PATCH 2/8] More general fix for integer overflow --- src/dictionaries.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dictionaries.jl b/src/dictionaries.jl index 2df3d2d..e77d22d 100644 --- a/src/dictionaries.jl +++ b/src/dictionaries.jl @@ -127,7 +127,7 @@ Return integer encoding the `permutation` (as base N) # encode_permutation(permutation) = sum((permutation .- 1) .* _base_factors(Val(length(permutation)))) @generated function _encode_permutation(permutation::NTuple{N,M}) where {N,M} s = :(0) - if N > 15 # Int64 overflows at N=16 + if big(N+1)^(N+1) > typemax(Int) # Int64 overflows at N=16 N = big(N) end for i = 1:N From 6f713b2d753f9b62541d3b9b000e6961aa2eb559 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 13:18:49 +0200 Subject: [PATCH 3/8] Revert Readme update --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 0cf9231..37c127e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # SparseVariables.jl -[![codecov](https://codecov.io/gh/hellemo/SparseVariables.jl/branch/main/graph/badge.svg?token=2LXGVU04YS)](https://codecov.io/gh/hellemo/SparseVariables.jl) - This package contains routines for improved and easier handling of sparse data and sparse arrays of optimizaton variables in JuMP. From 6d39e7b9e879d8e54d6d4c52dd1ad19938294ae9 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 13:26:55 +0200 Subject: [PATCH 4/8] Formatting fixes --- src/dictionaries.jl | 104 +++++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/src/dictionaries.jl b/src/dictionaries.jl index e77d22d..4581039 100644 --- a/src/dictionaries.jl +++ b/src/dictionaries.jl @@ -1,5 +1,5 @@ function variable_name(var::String, index) - return var *"[" * join(index,", ") * "]" + return var * "[" * join(index, ", ") * "]" end """ @@ -8,29 +8,29 @@ end Return function to use for filtering depending on the type and value of `c` to apply at position `pos` """ -make_filter_fun(c, pos) = x->x[pos] == c -make_filter_fun(c::Base.Fix2, pos) = x->c(x[pos]) -make_filter_fun(c::Function, pos) = x->c(x[pos]) -make_filter_fun(c::Colon, pos) = x->true +make_filter_fun(c, pos) = x -> x[pos] == c +make_filter_fun(c::Base.Fix2, pos) = x -> c(x[pos]) +make_filter_fun(c::Function, pos) = x -> c(x[pos]) +make_filter_fun(c::Colon, pos) = x -> true -make_filter_fun(c) = x->x==c -make_filter_fun(c::Base.Fix2) = x->c(x) -make_filter_fun(c::Function) = x->c(x) -make_filter_fun(c::Colon) = x->true -make_filter_fun(c::UnitRange) = x->(x ≥ c.start && x ≤ c.stop) +make_filter_fun(c) = x -> x == c +make_filter_fun(c::Base.Fix2) = x -> c(x) +make_filter_fun(c::Function) = x -> c(x) +make_filter_fun(c::Colon) = x -> true +make_filter_fun(c::UnitRange) = x -> (x ≥ c.start && x ≤ c.stop) """ recursive_filter(fs, data) Filter `data` recursively with functions `fs` """ function recursive_filter(fs, data) - (f, rest) = Iterators.peel(fs) - if isempty(rest) - return filter(f, data) - else - return recursive_filter(rest, filter(f, data)) - end + (f, rest) = Iterators.peel(fs) + if isempty(rest) + return filter(f, data) + else + return recursive_filter(rest, filter(f, data)) + end end @@ -41,7 +41,7 @@ Return functions to be used for filtering from a tuple following the format supported by `make_filter_fun` """ function indices_fun(some_tuple) - (make_filter_fun(v, pos) for (pos, v) in enumerate(some_tuple)) + (make_filter_fun(v, pos) for (pos, v) in enumerate(some_tuple)) end """ @@ -50,7 +50,7 @@ end Filter iterable data a by tuple `pattern` by row (slow) """ function _select_rowwise(a, pattern) - filter(x->reduce(&, f(x) for f in indices_fun(pattern)), a) + filter(x -> reduce(&, f(x) for f in indices_fun(pattern)), a) end @@ -60,7 +60,7 @@ end Filter iterable data a by tuple `pattern` by column (recursively) """ function _select_colwise(a, pattern) - recursive_filter(indices_fun(pattern), a) + recursive_filter(indices_fun(pattern), a) end """ @@ -70,7 +70,7 @@ Filter iterable data `a` by tuple `pattern` by row, using generated function for See more straight-forward implementations `_select_rowwise` and `_select_colwise` for reference. """ function _select_gen(a, pattern) - filter(x->_select_generated(pattern, x), a) + filter(x -> _select_generated(pattern, x), a) end """ @@ -91,7 +91,7 @@ _select_gen_perm(a, pat, (3,2,1)) # faster """ function _select_gen_perm(a, pattern, perm) p = Permutation(perm) - filter(x->_select_gen_permute(pattern, x, p), a) + filter(x -> _select_gen_permute(pattern, x, p), a) end """ @@ -101,7 +101,7 @@ Compose function from pattern `pat` to filter entire tuple at once, see `_select """ _select_generated(pat, x) = _select_generated(pat, x, Val(length(pat))) -@generated function _select_generated(pat,x,::Val{N}) where N +@generated function _select_generated(pat, x, ::Val{N}) where {N} ex = :(true) for i = 1:N ex = :($ex && make_filter_fun(pat[$i])(x[$i])) @@ -109,7 +109,8 @@ _select_generated(pat, x) = _select_generated(pat, x, Val(length(pat))) return :($ex) end -_select_gen_permute(pat,x,permutation) = _select_gen_permute(pat, x, permutation, Val(length(pat))) +_select_gen_permute(pat, x, permutation) = + _select_gen_permute(pat, x, permutation, Val(length(pat))) """ Permutation{N,K} @@ -117,7 +118,7 @@ Encode permutation as number of elements to permute `N` and number in sequence o to use for dispatch. """ struct Permutation{N,K} end -Permutation(t::Tuple) = Permutation{length(t), _encode_permutation(t)}() +Permutation(t::Tuple) = Permutation{length(t),_encode_permutation(t)}() """ _encode_permutation(permutation) @@ -125,14 +126,14 @@ Permutation(t::Tuple) = Permutation{length(t), _encode_permutation(t)}() Return integer encoding the `permutation` (as base N) """ # encode_permutation(permutation) = sum((permutation .- 1) .* _base_factors(Val(length(permutation)))) -@generated function _encode_permutation(permutation::NTuple{N,M}) where {N,M} +@generated function _encode_permutation(permutation::NTuple{N,M}) where {N,M} s = :(0) - if big(N+1)^(N+1) > typemax(Int) # Int64 overflows at N=16 + if big(N + 1)^(N + 1) > typemax(Int) # Int64 overflows at N=16 N = big(N) end for i = 1:N - bf = N^(N-i) - s = :($s + (permutation[$i]-1) * $bf) + bf = N^(N - i) + s = :($s + (permutation[$i] - 1) * $bf) end return s end @@ -149,25 +150,25 @@ function _decode_permutation(N, K) for i = 1:N p = div(tmp, base_factors[i]) tmp -= p * base_factors[i] - push!(perm, p+1) + push!(perm, p + 1) end return tuple(perm...) end -@generated function _base_factors(::Val{N}) where N +@generated function _base_factors(::Val{N}) where {N} base_factors = [] if N > 15 # Int64 overflows at N=16 N = big(N) end for i = N:-1:1 - push!(base_factors, N^(i-1)) + push!(base_factors, N^(i - 1)) end ex = :($base_factors) return ex end -@generated function _select_gen_permute(pat,x,::Permutation{N,K}) where {N,K} +@generated function _select_gen_permute(pat, x, ::Permutation{N,K}) where {N,K} ex = :(true) fs = [] for i = 1:N @@ -193,24 +194,27 @@ function select(dict, sh_pat::NamedTuple, names) end select(dict::Dictionary, indices) = getindices(dict, select(keys(dict), indices)) select(dict, f::Function) = filter(f, dict) -kselect(sa::SparseVarArray, sh_pat::NamedTuple) = select(keys(sa.data), sh_pat, get_index_names(sa)) -select(sa::SparseVarArray, sh_pat::NamedTuple) = Dictionaries.getindices(sa, kselect(sa, sh_pat)) +kselect(sa::SparseVarArray, sh_pat::NamedTuple) = + select(keys(sa.data), sh_pat, get_index_names(sa)) +select(sa::SparseVarArray, sh_pat::NamedTuple) = + Dictionaries.getindices(sa, kselect(sa, sh_pat)) -select_test(dict, indices, cache) = cache ? _select_cached(dict, indices) : _select_gen(keys(dict), indices) +select_test(dict, indices, cache) = + cache ? _select_cached(dict, indices) : _select_gen(keys(dict), indices) function _select_cached(sa, pat) - indices = Tuple(i for (i,v) in enumerate(pat) if v !== Colon()) + indices = Tuple(i for (i, v) in enumerate(pat) if v !== Colon()) vals = Tuple(v for v in pat if v !== Colon()) - + if !(indices in keys(sa.index_cache)) index = Dict() for v in keys(sa) - vred = Tuple(val for (i,val) in enumerate(v) if i in indices) + vred = Tuple(val for (i, val) in enumerate(v) if i in indices) if !(vred in keys(index)) index[vred] = [] end - push!(index[vred], v) + push!(index[vred], v) end sa.index_cache[indices] = index end @@ -219,8 +223,8 @@ end function permfromnames(names::NamedTuple, patnames) perm = (names[i] for i in patnames) - rest = setdiff((1:length(names)),perm) - return (perm...,rest...) + rest = setdiff((1:length(names)), perm) + return (perm..., rest...) end function expand_shorthand(sh_pat, names) @@ -229,7 +233,7 @@ function expand_shorthand(sh_pat, names) if haskey(sh_pat, n) push!(pat, sh_pat[n]) else - push!(pat, x->true) + push!(pat, x -> true) end end perm = permfromnames(names, propertynames(sh_pat)) @@ -242,10 +246,10 @@ Return false for functions, wildcards and ranges, true for all other types Works on types because it is used in generated function """ isfixed(t) = true -isfixed(::Type{T} where T<:Function) = false -isfixed(::Type{T} where T<:UnitRange) = false +isfixed(::Type{T} where {T<:Function}) = false +isfixed(::Type{T} where {T<:UnitRange}) = false iscolon(t) = false -iscolon(::Type{T} where T<:Colon) = true +iscolon(::Type{T} where {T<:Colon}) = true @generated function _getindex(sa::AbstractSparseArray{T,N}, tpl::Tuple) where {T,N} @@ -259,14 +263,14 @@ iscolon(::Type{T} where T<:Colon) = true end end end - + if lookup - return :( get(_data(sa), tpl, zero(T)) ) + return :(get(_data(sa), tpl, zero(T))) elseif !slice - return :( retval = select(_data(sa), tpl); length(retval)>0 ? retval : zero(T) ) + return :(retval = select(_data(sa), tpl); length(retval) > 0 ? retval : zero(T)) else # Return selection or zero if empty to avoid reduction of empty iterate - return :( retval = _select_var(sa, tpl); length(retval) > 0 ? retval : zero(T)) + return :(retval = _select_var(sa, tpl); length(retval) > 0 ? retval : zero(T)) end end -_select_var(sa, tpl) = getindices(_data(sa), select_test(sa, tpl, true)) \ No newline at end of file +_select_var(sa, tpl) = getindices(_data(sa), select_test(sa, tpl, true)) From 311f2d99d7b5b3c7d2b88a1ba750956523df6f38 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 14:14:53 +0200 Subject: [PATCH 5/8] Try fixing overflow on x86 --- src/dictionaries.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dictionaries.jl b/src/dictionaries.jl index b114798..9b3b6cf 100644 --- a/src/dictionaries.jl +++ b/src/dictionaries.jl @@ -125,7 +125,7 @@ Return integer encoding the `permutation` (as base N) # encode_permutation(permutation) = sum((permutation .- 1) .* _base_factors(Val(length(permutation)))) @generated function _encode_permutation(permutation::NTuple{N,M}) where {N,M} s = :(0) - if big(N + 1)^(N + 1) > typemax(Int) # Int64 overflows at N=16 + if sum(big(n)^N for n in 1:N) > typemax(Int) # Int64 overflows at N=16 N = big(N) end for i in 1:N From 57badf48206e62bd24db85f570a89704d58370f9 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 15:00:35 +0200 Subject: [PATCH 6/8] common function to check for overflow --- src/dictionaries.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dictionaries.jl b/src/dictionaries.jl index 9b3b6cf..9402be0 100644 --- a/src/dictionaries.jl +++ b/src/dictionaries.jl @@ -125,7 +125,7 @@ Return integer encoding the `permutation` (as base N) # encode_permutation(permutation) = sum((permutation .- 1) .* _base_factors(Val(length(permutation)))) @generated function _encode_permutation(permutation::NTuple{N,M}) where {N,M} s = :(0) - if sum(big(n)^N for n in 1:N) > typemax(Int) # Int64 overflows at N=16 + if will_overflow(N) N = big(N) end for i in 1:N @@ -152,9 +152,11 @@ function _decode_permutation(N, K) return tuple(perm...) end +will_overflow(N) = sum(big(n)^N for n in 1:N) > typemax(Int) + @generated function _base_factors(::Val{N}) where {N} base_factors = [] - if N > 15 # Int64 overflows at N=16 + if will_overflow(N) N = big(N) end for i in N:-1:1 From f84dc984c6b8e4b5aaff292f5c4888692fc5d9f6 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 15:04:38 +0200 Subject: [PATCH 7/8] formatting fix --- src/dictionaries.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dictionaries.jl b/src/dictionaries.jl index 9402be0..d0e6350 100644 --- a/src/dictionaries.jl +++ b/src/dictionaries.jl @@ -153,7 +153,7 @@ function _decode_permutation(N, K) end will_overflow(N) = sum(big(n)^N for n in 1:N) > typemax(Int) - + @generated function _base_factors(::Val{N}) where {N} base_factors = [] if will_overflow(N) From c721f8e6052d99efd23b64fa9e6948dbbc0e20d9 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Mon, 4 Apr 2022 15:14:15 +0200 Subject: [PATCH 8/8] More Int64 -> Int --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 73bb7f8..ac389b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -91,7 +91,7 @@ end end @testset "SparseArray" begin - @test typeof(car_cost) == SV.SparseArray{Int64,2,Tuple{String,Int64}} + @test typeof(car_cost) == SV.SparseArray{Int,2,Tuple{String,Int}} @test length(car_cost) == 5 @test car_cost["bmw", 2001] == 200