diff --git a/src/arithmetic.jl b/src/arithmetic.jl index b2f078e..b78bf49 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -133,3 +133,8 @@ function unsafe_push!(a::SparseArrays.SparseVector, k, v) a[k] = MA.add!!(a[k], v) return a end + +function unsafe_push!(a::Vector, k, v) + a[k] = MA.add!!(a[k], v) + return a +end diff --git a/src/mtables.jl b/src/mtables.jl index 20b5bee..f106b77 100644 --- a/src/mtables.jl +++ b/src/mtables.jl @@ -95,23 +95,6 @@ end _key(_, k) = k _key(mstr::MTable, k) = mstr[k] -function MA.operate!( - ms::UnsafeAddMul{<:MTable}, - res::AbstractCoefficients, - v::AbstractCoefficients, - w::AbstractCoefficients, -) - for (kv, a) in nonzero_pairs(v) - for (kw, b) in nonzero_pairs(w) - c = ms.structure(kv, kw) - for (k, v) in nonzero_pairs(c) - res[ms.structure[k]] += v * a * b - end - end - end - return res -end - function MA.operate!( ms::UnsafeAddMul{<:MTable}, res::AbstractSparseVector, @@ -138,20 +121,3 @@ function MA.operate!( res .+= sparsevec(idcs, vals, length(res)) return res end - -function MA.operate!( - ms::UnsafeAddMul{<:MTable}, - res::AbstractVector, - v::AbstractVector, - w::AbstractVector, -) - for (kv, a) in nonzero_pairs(v) - for (kw, b) in nonzero_pairs(w) - c = ms.structure(kv, kw) - for (k, v) in nonzero_pairs(c) - res[ms.structure[k]] += v * a * b - end - end - end - return res -end diff --git a/test/test_example_acoeffs.jl b/test/test_example_acoeffs.jl index 2c9272f..84d2857 100644 --- a/test/test_example_acoeffs.jl +++ b/test/test_example_acoeffs.jl @@ -25,3 +25,7 @@ end # the default arithmetic implementation uses this access Base.getindex(ac::ACoeffs, idx) = ac.vals[idx] Base.setindex!(ac::ACoeffs, val, idx) = ac.vals[idx] = val +function SA.unsafe_push!(ac::ACoeffs, idx, val) + ac.vals[idx] = MA.add!!(ac.vals[idx], val) + return ac +end