Skip to content

Commit

Permalink
handle A ~ B - 1 and add tests (#1086)
Browse files Browse the repository at this point in the history
* handle -1 and add tests

* replace `import Base.==` with `Base.:(==)`

* typo and error test
  • Loading branch information
kleinschmidt authored and ararslan committed Oct 3, 2016
1 parent 1658c35 commit 400da84
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
33 changes: 21 additions & 12 deletions src/statsmodels/formula.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type Terms
intercept::Bool # is there an intercept column in the model matrix?
end

Base.:(==)(t1::Terms, t2::Terms) = all(getfield(t1, f)==getfield(t2, f) for f in fieldnames(t1))

type ModelFrame
df::AbstractDataFrame
terms::Terms
Expand Down Expand Up @@ -85,19 +87,26 @@ function dospecials(ex::Expr)
if !(a1 in specials) return ex end
excp = copy(ex)
excp.args = vcat(a1,map(dospecials, ex.args[2:end]))
if a1 != :* return excp end
aa = excp.args
a2 = aa[2]
a3 = aa[3]
if length(aa) > 3
excp.args = vcat(a1, aa[3:end])
a3 = dospecials(excp)
if a1 == :-
a2, a3 = excp.args[2:3]
a3 == 1 || error("invalid expression $ex; subtraction only supported for -1")
return :($a2 + -1)
elseif a1 == :*
aa = excp.args
a2 = aa[2]
a3 = aa[3]
if length(aa) > 3
excp.args = vcat(a1, aa[3:end])
a3 = dospecials(excp)
end
## this order of expansion gives the R-style ordering of interaction
## terms (after sorting in increasing interaction order) for higher-
## order interaction terms (e.g. x1 * x2 * x3 should expand to x1 +
## x2 + x3 + x1&x2 + x1&x3 + x2&x3 + x1&x2&x3)
:($a2 + $a2 & $a3 + $a3)
else
excp
end
## this order of expansion gives the R-style ordering of interaction
## terms (after sorting in increasing interaction order) for higher-
## order interaction terms (e.g. x1 * x2 * x3 should expand to x1 +
## x2 + x3 + x1&x2 + x1&x3 + x2&x3 + x1&x2&x3)
:($a2 + $a2 & $a3 + $a3)
end
dospecials(a::Any) = a

Expand Down
7 changes: 4 additions & 3 deletions test/formula.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ module TestFormula
@test t.intercept == false
@test t.terms == [:x1, :x2]

t = Terms(y ~ -1 + x1 + x2)
@test t.intercept == false
@test t.terms == [:x1, :x2]
@test t == Terms(y ~ -1 + x1 + x2) == Terms(y ~ x1 - 1 + x2) == Terms(y ~ x1 + x2 -1)

## can't subtract terms other than 1
@test_throws ErrorException Terms(y ~ x1 - x2)

t = Terms(y ~ x1 & x2)
@test t.terms == [:(x1 & x2)]
Expand Down

0 comments on commit 400da84

Please sign in to comment.