Skip to content

Commit

Permalink
add significance_test function
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jan 31, 2024
1 parent c1e044c commit 31fa3ef
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 32 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.14.0"
version = "0.14.1"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand All @@ -19,7 +19,6 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -51,12 +50,11 @@ MLJModels = "0.15, 0.16"
MetaGraphsNext = "0.7"
Missings = "1.0"
PrecompileTools = "1.1.1"
PrettyTables = "2.2"
SplitApplyCombine = "1.2.2"
TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4.9"
Zygote = "0.6.69"
SplitApplyCombine = "1.2.2"
julia = "1.6, 1.7, 1"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using Distributions
using Zygote
using LogExpFunctions
using PrecompileTools
using PrettyTables
using Random
import AbstractDifferentiation as AD
using Graphs
Expand All @@ -32,7 +31,8 @@ export AVAILABLE_ESTIMANDS
export factorialATE, factorialIATE
export TMLEE, OSE, NAIVE
export ComposedEstimand
export var, estimate, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test,pvalue, confint, emptyIC
export var, estimate, pvalue, confint, emptyIC
export significance_test, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test
export compose
export TreatmentTransformer, with_encoder, encoder
export BackdoorAdjustment, identify
Expand Down
11 changes: 4 additions & 7 deletions src/counterfactual_mean_based/estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,6 @@ end

emptyIC(estimate; pval_threshold=nothing) = emptyIC(estimate, pval_threshold)


function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate)
testresult = OneSampleTTest(est)
data = [estimate(est) confint(testresult) pvalue(testresult);]
pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"])
end

"""
Distributions.estimate(r::EICEstimate)
Expand Down Expand Up @@ -104,3 +97,7 @@ Performs a T test on the EICEstimate.
HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) =
OneSampleTTest(est.estimate, est.std, est.n, Ψ₀)

significance_test(estimate::EICEstimate, Ψ₀=0) = OneSampleTTest(estimate, Ψ₀)

Base.show(io::IO, mime::MIME"text/plain", est::Union{EICEstimate, ComposedEstimand}) =
show(io, mime, significance_test(est))
21 changes: 9 additions & 12 deletions src/estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,6 @@ to_matrix(x) = reduce(hcat, x)
ComposedEstimate(;estimand, estimates, estimate, cov, n) =
ComposedEstimate(estimand, Tuple(estimates), collect(estimate), to_matrix(cov), n)


function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate)
if length(est.cov) !== 1
println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est)))
else
testresult = OneSampleTTest(est)
data = [estimate(est) confint(testresult) pvalue(testresult);]
headers = ["Estimate", "95% Confidence Interval", "P-value"]
pretty_table(io, data;header=headers)
end
end

"""
Distributions.estimate(r::ComposedEstimate)
Expand Down Expand Up @@ -171,6 +159,15 @@ function HypothesisTests.OneSampleZTest(estimate::ComposedEstimate, Ψ₀=0)
return OneSampleZTest(estimate.estimate[1], sqrt(estimate.cov[1]), estimate.n, Ψ₀)
end

function significance_test(estimate::ComposedEstimate, Ψ₀=zeros(size(estimate.estimate, 1)))
if length(estimate.estimate) == 1
Ψ₀ = Ψ₀ isa AbstractArray ? first(Ψ₀) : Ψ₀
return OneSampleTTest(estimate, Ψ₀)
else
return OneSampleHotellingT2Test(estimate, Ψ₀)
end
end

function emptyIC(estimate::ComposedEstimate, pval_threshold)
emptied_estimates = Tuple(emptyIC(e, pval_threshold) for e in estimate.estimates)
ComposedEstimate(estimate.estimand, emptied_estimates, estimate.estimate, estimate.cov, estimate.n)
Expand Down
4 changes: 2 additions & 2 deletions test/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ end
ose = OSE(models=TMLE.default_models(G=LogisticClassifier(), Q_continuous=LinearRegressor()))
jointEstimate, _ = ose(jointIATE, dataset, verbosity=0)

testres = OneSampleHotellingT2Test(jointEstimate)
testres = significance_test(jointEstimate)
@test testres. jointEstimate.estimate
@test pvalue(testres) < 1e-10

Expand All @@ -213,7 +213,7 @@ end
maybe_emptied_estimate = TMLE.emptyIC(jointEstimate, pval_threshold=pval_threshold)
n_empty = 0
for i in 1:3
pval = pvalue(OneSampleTTest(jointEstimate.estimates[i]))
pval = pvalue(significance_test(jointEstimate.estimates[i]))
maybe_emptied_IC = maybe_emptied_estimate.estimates[i].IC
if pval > pval_threshold
@test maybe_emptied_IC == []
Expand Down
2 changes: 1 addition & 1 deletion test/counterfactual_mean_based/non_regression_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using YAML

function regression_tests(tmle_result)
@test estimate(tmle_result) -0.185533 atol = 1e-6
l, u = confint(OneSampleTTest(tmle_result))
l, u = confint(significance_test(tmle_result))
@test l -0.279246 atol = 1e-6
@test u -0.091821 atol = 1e-6
@test OneSampleZTest(tmle_result) isa OneSampleZTest
Expand Down
5 changes: 1 addition & 4 deletions test/helper_fns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ at the given confidence level: here 0.05
"""
function test_coverage(result::TMLE.EICEstimate, Ψ₀)
# TMLE
lb, ub = confint(OneSampleTTest(result))
@test lb Ψ₀ ub
# OneStep
lb, ub = confint(OneSampleZTest(result))
lb, ub = confint(significance_test(result))
@test lb Ψ₀ ub
end

Expand Down

0 comments on commit 31fa3ef

Please sign in to comment.