diff --git a/src/transforms/remainder.jl b/src/transforms/remainder.jl
index 267b7e1..7ae76e3 100644
--- a/src/transforms/remainder.jl
+++ b/src/transforms/remainder.jl
@@ -10,18 +10,73 @@ and returns a new table with an additional column containing the
 remainder value `xₙ₊₁ = total .- (x₁ + x₂ + ⋯ + xₙ)` If the `total`
 value is not specified, then default to the maximum sum across rows.
 """
-struct Remainder <: Transform end
+struct Remainder <: Transform
+  total::Union{Float64,Nothing}
+end
+
+Remainder() = Remainder(nothing)
 
 isrevertible(::Type{Remainder}) = true
 
-function apply(::Remainder, table)
-  # TODO
+assertions(::Type{Remainder}) = [TT.assert_continuous]
+
+# total used to compute the remainder: the user-provided
+# value, or the maximum row sum when none was provided
+function _cache(transform::Remainder, table)
+  # design matrix
+  X = Tables.matrix(table)
+
+  # find total across rows
+  if !isnothing(transform.total)
+    transform.total
+  else
+    maximum(sum(X, dims=2))
+  end
 end
 
-function revert(::Remainder, newtable, cache)
-  # TODO
+# append the remainder column computed with the cached total;
+# returns the new table together with the total (the cache)
+function _apply(transform::Remainder, table, cache)
+  # basic checks
+  for assertion in assertions(transform)
+    assertion(table)
+  end
+
+  # design matrix
+  X = Tables.matrix(table)
+
+  # retrieve the total
+  total = cache
+
+  # sum across rows, computed once for the check and the remainder
+  S = sum(X, dims=2)
+
+  # make sure that the total is valid; an explicit throw is used
+  # because `@assert` may be disabled at some optimization levels
+  all(s -> s ≤ total, S) ||
+    throw(AssertionError("the sum for each row must be less than or equal to total"))
+
+  # original column names
+  names = Tables.columnnames(table)
+
+  # create a column with the remainder
+  Z = [X (total .- S)]
+
+  # table with the new column
+  rname = Symbol("total_minus_", join(string.(names)))
+  names = (names..., rname)
+  𝒯 = (; zip(names, eachcol(Z))...)
+  newtable = 𝒯 |> Tables.materializer(table)
+
+  newtable, total
 end
 
-function reapply(::Remainder, table, cache)
-  # TODO
+apply(transform::Remainder, table) = _apply(transform, table, _cache(transform, table))
+
+# drop the remainder column, which `_apply` always appends last
+function revert(::Remainder, newtable, cache)
+  names = Tables.columnnames(newtable)
+  TT.Reject(last(names))(newtable)
 end
+
+reapply(transform::Remainder, table, cache) = _apply(transform, table, cache)
diff --git a/test/transforms.jl b/test/transforms.jl
index c523bec..c6c3026 100644
--- a/test/transforms.jl
+++ b/test/transforms.jl
@@ -36,4 +36,20 @@
   tcls = revert(Closure(), n, c)
   @test Tables.matrix(n) ≈ [0.2 0.4 0.4; 0.66 0.22 0.12; 0.00 0.02 0.98;]
   @test Tables.matrix(tcls) ≈ Tables.matrix(t)
+
+  # Tests for Remainder
+  t = (a=Float64[2,66,0], b=Float64[4,22,2], c=Float64[4,12,98])
+  n, c = apply(Remainder(), t)
+  trem = revert(Remainder(), n, c)
+  Xt = Tables.matrix(t)
+  Xn = Tables.matrix(n)
+  @test Xn[:, 1:end-1] == Xt
+  @test all(x -> 0 ≤ x ≤ c, Xn[:, end])
+  @test n |> Tables.columnnames == (:a, :b, :c, :total_minus_abc)
+  @test trem |> Tables.columnnames == (:a, :b, :c)
+
+  t = (a=Float64[1,10,0], b=Float64[1,5,0], c=Float64[4,2,1])
+  n, c = reapply(Remainder(), t, c)
+  Xn = Tables.matrix(n)
+  @test all(x -> 0 ≤ x ≤ c, Xn[:, end])
 end