From 9bdba0bca555407d676c65a592ee775cd28fa984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20A=2E=20S=2E=20Silva?= Date: Tue, 2 Nov 2021 12:37:05 -0300 Subject: [PATCH 1/5] clean commit with approved changes --- src/transforms/remainder.jl | 60 ++++++++++++++++++++++++++++++++----- test/transforms.jl | 13 ++++++++ 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/transforms/remainder.jl b/src/transforms/remainder.jl index 267b7e1..daa9451 100644 --- a/src/transforms/remainder.jl +++ b/src/transforms/remainder.jl @@ -10,18 +10,64 @@ and returns a new table with an additional column containing the remainder value `xₙ₊₁ = total .- (x₁ + x₂ + ⋯ + xₙ)` If the `total` value is not specified, then default to the maximum sum across rows. """ -struct Remainder <: Transform end +struct Remainder <: Transform + total::Union{Float64,Nothing} +end + +Remainder() = Remainder(nothing) isrevertible(::Type{Remainder}) = true -function apply(::Remainder, table) - # TODO +assertions(::Type{Remainder}) = [TT.assert_continuous] + +function _cache(transform::Remainder, table) + # design matrix + X = Tables.matrix(table) + + # find total across rows + if !isnothing(transform.total) + transform.total + else + maximum(sum(X, dims=2)) + end end -function revert(::Remainder, newtable, cache) - # TODO +function _apply(transform::Remainder, table, cache) + # basic checks + for assertion in assertions(transform) + assertion(table) + end + + # design matrix + X = Tables.matrix(table) + + # retrieve the total + total = cache + + # make sure that the total passed is geater than or equal to sums across the rows + @assert all(x -> x ≤ total, sum(X, dims=2)) "the sum across rows must be less than total" + + # original column names + names = Tables.columnnames(table) + + # create a column with the remainder + S = sum(X, dims=2) + Z = [X (total .- S)] + + # table with the new column + rname = Symbol("total_minus_" * join(string.(names))) + names = (names..., rname) + 𝒯 = (; zip(names, eachcol(Z))...) + newtable = 𝒯 |> Tables.materializer(table) + + newtable, total end -function reapply(::Remainder, table, cache) - # TODO +apply(transform::Remainder, table) = _apply(transform, table, _cache(transform, table)) + +function revert(::Remainder, newtable, cache) + names = Tables.columnnames(newtable) + TT.Reject(last(names))(newtable) end + +reapply(transform::Remainder, table, cache) = _apply(transform, table, cache) \ No newline at end of file diff --git a/test/transforms.jl b/test/transforms.jl index c523bec..ee58bb5 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -36,4 +36,17 @@ tcls = revert(Closure(), n, c) @test Tables.matrix(n) ≈ [0.2 0.4 0.4; 0.66 0.22 0.12; 0.00 0.02 0.98;] @test Tables.matrix(tcls) ≈ Tables.matrix(t) + + # Tests for Remainder + t = (a=Float64[2,66,0], b=Float64[4,22,2], c=Float64[4,12,98]) + n, c = apply(Remainder(), t) + trem = revert(Remainder(), n, c) + @test Tables.matrix(n)[:, 1:end-1] == Tables.matrix(t) + @test all(x -> 0 ≤ x ≤ c, Tables.matrix(n)[:, end]) + @test n |> Tables.columnnames == (:a, :b, :c, :total_minus_abc) + @test trem |> Tables.columnnames == (:a, :b, :c) + + t = (a=Float64[1,10,0], b=Float64[1,5,0], c=Float64[4,2,1]) + n, c = reapply(Remainder(), t, c) + @test all(x -> 0 ≤ x ≤ c, Tables.matrix(n)[:, end]) end From 85b6c313ef666f84b660ce15dfa8f7e7a5a99c57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Tue, 2 Nov 2021 12:42:03 -0300 Subject: [PATCH 2/5] Update src/transforms/remainder.jl --- src/transforms/remainder.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transforms/remainder.jl b/src/transforms/remainder.jl index daa9451..b25a3d1 100644 --- a/src/transforms/remainder.jl +++ b/src/transforms/remainder.jl @@ -44,8 +44,8 @@ function _apply(transform::Remainder, table, cache) # retrieve the total total = cache - # make sure that the total passed is geater than or equal to sums across the rows - @assert all(x -> x ≤ total, sum(X, dims=2)) "the sum across rows must be less than total" + # make sure that the total is valid + @assert all(x -> x ≤ total, sum(X, dims=2)) "the sum for each row must be less than total" # original column names names = Tables.columnnames(table) From 5dd434f8a73568078a425cef13884c2ef081199c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Tue, 2 Nov 2021 12:42:53 -0300 Subject: [PATCH 3/5] Update src/transforms/remainder.jl --- src/transforms/remainder.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transforms/remainder.jl b/src/transforms/remainder.jl index b25a3d1..7ae76e3 100644 --- a/src/transforms/remainder.jl +++ b/src/transforms/remainder.jl @@ -55,7 +55,7 @@ function _apply(transform::Remainder, table, cache) Z = [X (total .- S)] # table with the new column - rname = Symbol("total_minus_" * join(string.(names))) + rname = Symbol("total_minus_", join(string.(names))) names = (names..., rname) 𝒯 = (; zip(names, eachcol(Z))...) newtable = 𝒯 |> Tables.materializer(table) From 00e135825a6e7b1f17685c90b88b76124e5d8bd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Tue, 2 Nov 2021 12:44:58 -0300 Subject: [PATCH 4/5] Update test/transforms.jl --- test/transforms.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/transforms.jl b/test/transforms.jl index ee58bb5..723783c 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -41,8 +41,10 @@ t = (a=Float64[2,66,0], b=Float64[4,22,2], c=Float64[4,12,98]) n, c = apply(Remainder(), t) trem = revert(Remainder(), n, c) - @test Tables.matrix(n)[:, 1:end-1] == Tables.matrix(t) - @test all(x -> 0 ≤ x ≤ c, Tables.matrix(n)[:, end]) + Xt = Tables.matrix(t) + Xn = Tables.matrix(n) + @test Xn[:, 1:end-1] == Xt + @test all(x -> 0 ≤ x ≤ c, Xn[:, end]) @test n |> Tables.columnnames == (:a, :b, :c, :total_minus_abc) @test trem |> Tables.columnnames == (:a, :b, :c) From 9227c328b3b8233dc696af044a03c0170724d23b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Tue, 2 Nov 2021 12:45:45 -0300 Subject: [PATCH 5/5] Update test/transforms.jl --- test/transforms.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/transforms.jl b/test/transforms.jl index 723783c..c6c3026 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -50,5 +50,6 @@ t = (a=Float64[1,10,0], b=Float64[1,5,0], c=Float64[4,2,1]) n, c = reapply(Remainder(), t, c) - @test all(x -> 0 ≤ x ≤ c, Tables.matrix(n)[:, end]) + Xn = Tables.matrix(n) + @test all(x -> 0 ≤ x ≤ c, Xn[:, end]) end